Unverified Commit e614aa34 authored by Wenhao Chen, committed by GitHub

[shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests
parent df5e9c53
@@ -109,8 +109,8 @@ class MixtralPolicy(Policy):
         else:
             module = self.model.model
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": partial(new_forward, stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
@@ -129,10 +129,10 @@ class MixtralPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
...

@@ -26,7 +26,7 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
 from colossalai.tensor.d_tensor.api import is_distributed_tensor
@@ -930,6 +930,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy (Policy, optional): Custom policy for Shardformer. Defaults to None.
         pp_style (str, optional): The style for pipeline parallelism. Defaults to '1f1b'.
         num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
+        gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
         enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
     """
@@ -969,6 +970,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         custom_policy: Policy = None,
         pp_style: str = "1f1b",
         num_model_chunks: int = 1,
+        gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
         enable_metadata_cache: bool = True,
     ) -> None:
         super().__init__()
@@ -1043,6 +1045,7 @@ class HybridParallelPlugin(PipelinePluginBase):
             enable_sequence_parallelism=enable_sequence_parallelism,
             enable_sequence_overlap=enable_sequence_overlap,
             parallel_output=parallel_output,
+            gradient_checkpoint_config=gradient_checkpoint_config,
         )
         self.amp_config = dict(
             initial_scale=initial_scale,
...

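A minimal sketch of how the new argument might be wired in user code, assuming `PipelineGradientCheckpointConfig` exposes a `gradient_checkpointing_ratio` field as the PR title suggests (that field name is an assumption, not shown in this diff):

```python
# Hypothetical usage sketch: pass a gradient checkpoint config to the plugin.
# `gradient_checkpointing_ratio` is an assumed field name (from the PR title);
# the actual dataclass fields may differ.
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer import PipelineGradientCheckpointConfig

plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_model_chunks=1,
    gradient_checkpoint_config=PipelineGradientCheckpointConfig(gradient_checkpointing_ratio=0.5),
)
booster = Booster(plugin=plugin)
```
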
@@ -114,12 +114,12 @@ class BloomModelInferPolicy(BloomForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
...

@@ -69,11 +69,11 @@ class ChatGLM2InferPolicy(ChatGLMModelPolicy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
             held_layers.append(module.output_layer)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:
...

@@ -194,11 +194,11 @@ class LlamaModelInferPolicy(LlamaForCausalLMPolicy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(self.model.lm_head)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.norm)
...

 import contextlib
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, Union
+import numpy as np
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -29,6 +30,8 @@ class PipelineStageManager:
     ) -> None:
         assert enable_interleave or num_model_chunks == 1, "num_model_chunks must be 1 when enable_interleave is False"
+        self.num_layers_per_stage = None
+
         self.pg_mesh = pg_mesh
         self.pipeline_axis = pipeline_axis
         self.prev_rank: Optional[Tuple[int, ...]] = None
@@ -69,6 +72,88 @@ class PipelineStageManager:
         # for shardformer, hold model chunk id
         self.model_chunk_id: Optional[int] = None
+
+    @property
+    def control_distribute_layers(self) -> bool:
+        return self.num_layers_per_stage is not None
+
+    def set_distribution_config(self, num_model_layers: int, num_layers_per_stage: List[int]) -> None:
+        """Set the distribution configuration.
+        This allows user to customize the number of layers for each stage.
+
+        Args:
+            num_model_layers (int): Number of layers in the model.
+            num_layers_per_stage (List[int]): Number of layers for each stage.
+        """
+        assert all([0 < num_layers < num_model_layers for num_layers in num_layers_per_stage])
+        assert sum(num_layers_per_stage) == num_model_layers
+        assert len(num_layers_per_stage) == self.num_stages * (self.num_model_chunks if self.is_interleave else 1)
+        self.num_model_layers = num_model_layers
+        self.num_layers_per_stage = num_layers_per_stage
+
+    def distribute_layers(
+        self, num_layers: int, num_stages: Optional[int] = None, num_model_chunks: Optional[int] = None
+    ) -> List[int]:
+        """Divide layers into stages"""
+        num_stages = self.num_stages if num_stages is None else num_stages
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+
+        if self.control_distribute_layers:
+            assert num_layers == self.num_model_layers
+            return self.num_layers_per_stage
+        else:
+            quotient = num_layers // (num_stages * num_model_chunks)
+            remainder = num_layers % (num_stages * num_model_chunks)
+
+            # calculate the num_layers per stage
+            layers_per_stage = [quotient] * num_stages * num_model_chunks
+
+            # deal with the rest layers
+            if remainder > 0:
+                start_position = (num_stages * num_model_chunks) // 2 - remainder // 2
+                for i in range(start_position, start_position + remainder):
+                    layers_per_stage[i] += 1
+            return layers_per_stage
+
+    def get_stage_index(
+        self,
+        layers_per_stage: List[int],
+        stage: Optional[int] = None,
+        num_model_chunks: Optional[int] = None,
+        num_stages: Optional[int] = None,
+    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+        """
+        Get the start index and end index of layers for each stage.
+
+        Args:
+            layers_per_stage (List[int]): number of layers for each stage
+            stage (int): the stage index
+            num_stages (int): number of stages
+            num_model_chunks (int): number of model chunks
+
+        Returns:
+            - Tuple[int, int]: the start index and end index of this stage
+            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
+        """
+        stage = self.stage if stage is None else stage
+        num_model_chunks = (
+            (self.num_model_chunks if self.is_interleave else 1) if num_model_chunks is None else num_model_chunks
+        )
+        num_stages = self.num_stages if num_stages is None else num_stages
+
+        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
+        stage_indices = []
+        for model_chunk in range(num_model_chunks):
+            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
+            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
+            stage_indices.append([start_idx, end_idx])
+
+        return stage_indices[0] if num_model_chunks == 1 else stage_indices
+
     def is_first_stage(self, ignore_chunk: bool = False) -> bool:
         """Is the current stage the first stage.
...

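To illustrate the default splitting and indexing logic added above, here is a small self-contained sketch (plain Python/NumPy mirroring the logic, not the ColossalAI API itself): with the even split, surplus layers go to the middle stages, and `get_stage_index` reads each stage's start/end offsets from the cumulative sum.

```python
# Self-contained illustration of the default split used by distribute_layers /
# get_stage_index above (single model chunk case; no ColossalAI imports).
import numpy as np

def distribute_layers(num_layers: int, num_stages: int) -> list:
    quotient, remainder = divmod(num_layers, num_stages)
    layers_per_stage = [quotient] * num_stages
    # surplus layers are placed on the middle stages
    start = num_stages // 2 - remainder // 2
    for i in range(start, start + remainder):
        layers_per_stage[i] += 1
    return layers_per_stage

def get_stage_index(layers_per_stage, stage):
    acc = np.insert(np.cumsum(layers_per_stage), 0, 0)
    return int(acc[stage]), int(acc[stage + 1])

layers = distribute_layers(30, 4)          # [7, 8, 8, 7]
print(layers, get_stage_index(layers, 1))  # [7, 8, 8, 7] (7, 15)
```
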
-from .shard import ShardConfig, ShardFormer
+from .shard import GradientCheckpointConfig, ModelSharder, PipelineGradientCheckpointConfig, ShardConfig, ShardFormer

@@ -138,13 +138,25 @@ class LlamaPipelineForwards:
         next_decoder_cache = () if use_cache else None
         start_idx, end_idx = stage_index[0], stage_index[1]
+        num_ckpt_layers = 0
+        if self.gradient_checkpointing and self.training:
+            num_ckpt_layers = end_idx - start_idx
+            # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+            if shard_config.gradient_checkpoint_config is not None:
+                num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+                    stage=stage_manager.stage,
+                    num_layers=end_idx - start_idx,
+                    model_chunk_id=stage_manager.model_chunk_id if stage_manager.is_interleave else 0,
+                )
+            assert num_ckpt_layers <= end_idx - start_idx
+
         for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
             past_key_value = past_key_values[idx] if past_key_values is not None else None
-            if self.gradient_checkpointing and self.training:
+            if idx - start_idx < num_ckpt_layers:
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
...

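The effect of the new `num_ckpt_layers` gate is that only the first `num_ckpt_layers` decoder layers owned by a stage are recomputed during backward; the remaining layers run a normal forward. A minimal self-contained sketch of that pattern, using plain `torch` modules rather than the actual Llama modeling code:

```python
# Minimal sketch of the per-layer checkpointing gate above, with generic
# torch modules standing in for Llama decoder layers.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
start_idx, end_idx = 0, 8
num_ckpt_layers = 4  # e.g. checkpoint half of this stage's layers

hidden_states = torch.randn(2, 16, requires_grad=True)
for idx, layer in enumerate(layers[start_idx:end_idx], start=start_idx):
    if idx - start_idx < num_ckpt_layers:
        # activations of this layer are recomputed in the backward pass
        hidden_states = checkpoint(layer, hidden_states, use_reentrant=False)
    else:
        hidden_states = layer(hidden_states)

hidden_states.sum().backward()
```
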
@@ -2,9 +2,8 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
-import numpy as np
+from typing import Any, Callable, Dict, List, Optional, Union
 import torch.nn as nn
 from torch import Tensor
 from torch.nn import Module
@@ -196,49 +195,3 @@ class Policy(ABC):
             List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
         """
         return []
-
-    def distribute_layers(self, num_layers: int, num_stages: int) -> List[int]:
-        """Divide layers into stages"""
-        quotient = num_layers // num_stages
-        remainder = num_layers % num_stages
-
-        # calculate the num_layers per stage
-        layers_per_stage = [quotient] * num_stages
-
-        # deal with the rest layers
-        if remainder > 0:
-            start_position = num_stages // 2 - remainder // 2
-            for i in range(start_position, start_position + remainder):
-                layers_per_stage[i] += 1
-        return layers_per_stage
-
-    def get_stage_index(
-        self,
-        layers_per_stage: List[int],
-        stage: int,
-        num_model_chunks: int = 1,
-        num_stages: int = 0,
-    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
-        """
-        Get the start index and end index of layers for each stage.
-
-        Args:
-            layers_per_stage (List[int]): number of layers for each stage
-            stage (int): the stage index
-            num_stages (int): number of stages
-            num_model_chunks (int): number of model chunks
-
-        Returns:
-            - Tuple[int, int]: the start index and end index of this stage
-            - List[Tuple[int, int]]: the start index and end index of this stage for each model chunk
-        """
-        num_layers_per_stage_accumulated = np.insert(np.cumsum(layers_per_stage), 0, 0)
-        stage_indices = []
-        for model_chunk in range(num_model_chunks):
-            start_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages]
-            end_idx = num_layers_per_stage_accumulated[stage + model_chunk * num_stages + 1]
-            stage_indices.append([start_idx, end_idx])
-        return stage_indices[0] if num_model_chunks == 1 else stage_indices

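For downstream policies the migration is mechanical: the two helpers now live on `PipelineStageManager`, which fills in `stage`, `num_stages`, and `num_model_chunks` from its own state. The before/after pattern, as it recurs throughout the policy hunks below:

```python
# Old: helpers on the policy, stage info passed explicitly
layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)

# New: helpers on the stage manager, defaults taken from the manager itself
layers_per_stage = stage_manager.distribute_layers(len(module.layers))
stage_index = stage_manager.get_stage_index(layers_per_stage)
```
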
@@ -279,16 +279,8 @@ class BertPolicy(Policy):
             module = self.model.bert
         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -298,8 +290,8 @@ class BertPolicy(Policy):
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -324,16 +316,8 @@ class BertPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.encoder.layer),
-                stage_manager.num_stages * stage_manager.num_model_chunks,
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embeddings)
             for start_idx, end_idx in stage_indices:
@@ -342,10 +326,10 @@ class BertPolicy(Policy):
                 held_layers.append(module.pooler)
         else:
-            layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embeddings)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.encoder.layer[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.pooler)
...

@@ -203,8 +203,8 @@ class BloomPolicy(Policy):
         else:
             module = self.model.transformer
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -226,11 +226,11 @@ class BloomPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
             held_layers.append(module.word_embeddings_layernorm)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
...

@@ -179,10 +179,10 @@ class ChatGLMPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
         if stage_manager.is_first_stage():
             held_layers.append(module.embedding)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             if module.encoder.post_layer_norm:
@@ -204,8 +204,8 @@ class ChatGLMPolicy(Policy):
         else:
             module = self.model.transformer
-        layers_per_stage = self.distribute_layers(module.num_layers, stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(module.num_layers)
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
...

@@ -161,8 +161,8 @@ class FalconPolicy(Policy):
         else:
             module = self.model.transformer
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
@@ -181,10 +181,10 @@ class FalconPolicy(Policy):
         module = self.model.transformer
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.word_embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
...

@@ -185,15 +185,8 @@ class GPT2Policy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
@@ -203,12 +196,12 @@ class GPT2Policy(Policy):
             if stage_manager.is_last_stage(ignore_chunk=True):
                 held_layers.append(module.ln_f)
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
             if stage_manager.is_first_stage():
                 held_layers.append(module.wte)
                 held_layers.append(module.wpe)
                 held_layers.append(module.drop)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.h[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.ln_f)
@@ -226,15 +219,8 @@ class GPT2Policy(Policy):
             module = self.model.transformer
         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.h), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
@@ -243,8 +229,8 @@ class GPT2Policy(Policy):
                 )
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.h))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward,
...

@@ -179,11 +179,11 @@ class GPTJPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
         if stage_manager.is_first_stage():
             held_layers.append(module.wte)
             held_layers.append(module.drop)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.h[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.ln_f)
@@ -200,8 +200,8 @@ class GPTJPolicy(Policy):
         else:
             module = self.model.transformer
-        layers_per_stage = self.distribute_layers(len(module.h), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.h))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
...

@@ -164,30 +164,20 @@ class LlamaPolicy(Policy):
             module = self.model.model
         if stage_manager.is_interleave:
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_manager.stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
             }
         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-            stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_index = stage_manager.get_stage_index(layers_per_stage)
             method_replacement = {
                 "forward": partial(
                     new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
                 )
             }
-        self.append_or_create_method_replacement(
-            description=method_replacement, policy=policy, target_key=model_cls
-        )
         self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
@@ -204,15 +194,8 @@ class LlamaPolicy(Policy):
         held_layers = []
         if stage_manager.is_interleave:
             assert stage_manager.num_model_chunks is not None
-            layers_per_stage = self.distribute_layers(
-                len(module.layers), stage_manager.num_stages * stage_manager.num_model_chunks
-            )
-            stage_indices = self.get_stage_index(
-                layers_per_stage,
-                stage_manager.stage,
-                num_model_chunks=stage_manager.num_model_chunks,
-                num_stages=stage_manager.num_stages,
-            )
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+            stage_indices = stage_manager.get_stage_index(layers_per_stage)
             if stage_manager.is_first_stage(ignore_chunk=True):
                 held_layers.append(module.embed_tokens)
             for start_idx, end_idx in stage_indices:
@@ -221,10 +204,10 @@ class LlamaPolicy(Policy):
                 held_layers.append(module.norm)
         else:
-            layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+            layers_per_stage = stage_manager.distribute_layers(len(module.layers))
             if stage_manager.is_first_stage():
                 held_layers.append(module.embed_tokens)
-            start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+            start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
             held_layers.extend(module.layers[start_idx:end_idx])
             if stage_manager.is_last_stage():
                 held_layers.append(module.norm)
...

@@ -186,12 +186,12 @@ class OPTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
         if stage_manager.is_first_stage():
             held_layers.append(module.embed_tokens)
             held_layers.append(module.embed_positions)
             held_layers.append(module.project_in)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.layers[start_idx:end_idx])
         if stage_manager.is_last_stage():
             held_layers.append(module.final_layer_norm)
@@ -208,8 +208,8 @@ class OPTPolicy(Policy):
         else:
             module = self.model.model.decoder
-        layers_per_stage = self.distribute_layers(len(module.layers), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {
             "forward": partial(
                 new_forward,
...

@@ -251,6 +251,8 @@ class T5BasePolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."
         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -262,7 +264,7 @@ class T5BasePolicy(Policy):
         # in the case of T5EncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -273,21 +275,26 @@ class T5BasePolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
     def get_t5_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "Pipeline stage manager is not set."
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
+            return stage_manager.get_stage_index(
+                layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage
+            )
     def get_held_layers(self) -> List[nn.Module]:
         """Get pipeline layers for current stage."""
...

@@ -134,10 +134,10 @@ class ViTPolicy(Policy):
         stage_manager = self.pipeline_stage_manager
         held_layers = []
-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
         if stage_manager.is_first_stage():
             held_layers.append(module.embeddings)
-        start_idx, end_idx = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
         held_layers.extend(module.encoder.layer[start_idx:end_idx])
         return held_layers
@@ -149,8 +149,8 @@ class ViTPolicy(Policy):
         else:
             module = self.model.vit
-        layers_per_stage = self.distribute_layers(len(module.encoder.layer), stage_manager.num_stages)
-        stage_index = self.get_stage_index(layers_per_stage, stage_manager.stage)
+        layers_per_stage = stage_manager.distribute_layers(len(module.encoder.layer))
+        stage_index = stage_manager.get_stage_index(layers_per_stage)
         method_replacement = {"forward": pipeline_forward(stage_manager=stage_manager, stage_index=stage_index)}
         self.append_or_create_method_replacement(
             description=method_replacement, policy=policy, target_key=model_cls
...

@@ -300,6 +300,8 @@ class WhisperPolicy(Policy):
         Return the layer distribution as a list and the starting stage of decoder.
         If decoder doesn't exist, returned decoder starting stage is set to num_encoder_layers.
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"
         # number of encoder layers must be a positive integer
         if num_encoder_layers <= 0:
@@ -311,7 +313,7 @@ class WhisperPolicy(Policy):
         # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
         if num_decoder_layers == 0:
-            return self.distribute_layers(num_encoder_layers, num_stages), num_stages
+            return stage_manager.distribute_layers(num_encoder_layers, num_stages), num_stages
         # the number of stages distributed between encoder and decoder is optimized in this way:
         # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
@@ -322,21 +324,24 @@ class WhisperPolicy(Policy):
         num_encoder_stages = np.argmin([objective(i) for i in range(1, num_stages)]) + 1
         num_decoder_stages = num_stages - num_encoder_stages
-        encoder_distribution = self.distribute_layers(num_encoder_layers, num_encoder_stages)
-        decoder_distribution = self.distribute_layers(num_decoder_layers, num_decoder_stages)
+        encoder_distribution = stage_manager.distribute_layers(num_encoder_layers, num_encoder_stages)
+        decoder_distribution = stage_manager.distribute_layers(num_decoder_layers, num_decoder_stages)
         return encoder_distribution + decoder_distribution, num_encoder_stages
     def get_whisper_stage_index(
         self, layers_per_stage: List[int], stage: int, decoder_starting_stage: int
-    ) -> Tuple[bool, int, int]:
+    ) -> Tuple[int, int]:
         """
         Input the distribution of layers among stages, the current stage and the first stage of decoder.
         Return the starting/ending idx of layers in encoder/decoder
         """
+        stage_manager = self.pipeline_stage_manager
+        assert stage_manager is not None, "pipeline_stage_manager is None"
         if stage < decoder_starting_stage:
-            return self.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
+            return stage_manager.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
         else:
-            return self.get_stage_index(
+            return stage_manager.get_stage_index(
                 layers_per_stage[decoder_starting_stage:],
                 stage - decoder_starting_stage,
             )
...

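The heterogeneous shard policy from the title hinges on `set_distribution_config`: once a custom per-stage layer count is registered, `distribute_layers` returns it instead of computing the even split. A hedged sketch, assuming `stage_manager` is an already-constructed `PipelineStageManager` for a 4-stage, non-interleaved pipeline:

```python
# Sketch of heterogeneous layer distribution via the new API.
# `stage_manager` is assumed to be an existing PipelineStageManager
# (4 stages, non-interleaved); it is not constructed here.
num_model_layers = 32
custom_split = [6, 10, 10, 6]  # must sum to num_model_layers, one entry per stage

stage_manager.set_distribution_config(num_model_layers, custom_split)

# distribute_layers now returns the custom split instead of the even one
assert stage_manager.distribute_layers(num_model_layers) == custom_split
```
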