Unverified Commit efba0f44 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
parents ac178ca5 fae6c92e
import warnings
from functools import partial
from typing import Callable, Dict, List, Tuple
import numpy as np
import torch.nn as nn
from torch import Tensor
import colossalai.shardformer.layer as col_nn
from .._utils import getattr_, setattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
from ..modeling.whisper import (
WhisperPipelineForwards,
get_jit_fused_whisper_decoder_layer_forward,
get_jit_fused_whisper_encoder_layer_forward,
get_whisper_flash_attention_forward,
......@@ -12,7 +19,8 @@ from ..modeling.whisper import (
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
__all__ = [
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy', 'WhisperForAudioClassification'
'WhisperPolicy', 'WhisperModelPolicy', 'WhisperForConditionalGenerationPolicy',
'WhisperForAudioClassificationPolicy'
]
......@@ -26,7 +34,6 @@ class WhisperPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
......@@ -45,6 +52,11 @@ class WhisperPolicy(Policy):
policy = {}
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
"Whisper dosen't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[WhisperEncoderLayer] = ModulePolicyDescription(attribute_replacement={
"self_attn.embed_dim":
......@@ -191,20 +203,26 @@ class WhisperPolicy(Policy):
# enable flash attention
if self.shard_config.enable_flash_attention:
policy[WhisperAttention] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_whisper_flash_attention_forward(),
})
},
policy=policy,
target_key=WhisperAttention)
# use jit fused operator
if self.shard_config.enable_jit_fused:
policy[WhisperEncoderLayer] = ModulePolicyDescription(method_replacement={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
policy[WhisperDecoderLayer] = ModulePolicyDescription(method_replacement={
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_decoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
})
},
policy=policy,
target_key=WhisperDecoderLayer)
self.append_or_create_method_replacement(description={
'forward': get_jit_fused_whisper_encoder_layer_forward(),
'dropout_add': get_jit_fused_dropout_add_func(),
},
policy=policy,
target_key=WhisperEncoderLayer)
return policy
......@@ -223,6 +241,146 @@ class WhisperPolicy(Policy):
def postprocess(self):
    # Base whisper policy needs no post-sharding fix-ups; return the model unchanged.
    return self.model
@staticmethod
def distribute_whisper_layers(num_encoder_layers: int, num_decoder_layers: int,
                              num_stages: int) -> Tuple[List[int], int]:
    """
    Distribute whisper layers into stages when pipeline parallel is used.

    Args:
        num_encoder_layers: number of encoder layers; must be positive.
        num_decoder_layers: number of decoder layers; 0 for encoder-only models.
        num_stages: total number of pipeline stages.

    Returns:
        A tuple of (per-stage layer counts, starting stage of the decoder).
        If the decoder doesn't exist, the returned decoder starting stage is
        ``num_stages`` so that no stage is classified as a decoder stage.

    Raises:
        ValueError: if a layer count is invalid or there are fewer layers than stages.
    """
    # number of encoder layers must be a positive integer
    if num_encoder_layers <= 0:
        raise ValueError("The number of encoder layers for whisper must be a positive integer.")

    # a negative decoder layer count is meaningless and would corrupt the split below
    if num_decoder_layers < 0:
        raise ValueError("The number of decoder layers for whisper must be non-negative.")

    # number of layers should be large enough to fill in every stage
    if num_encoder_layers + num_decoder_layers < num_stages:
        raise ValueError("The total number of layers can't be smaller than number of stages.")

    # in the case of whisperEncoderModel, set decoder starting stage to num_stages since it doesn't exist
    if num_decoder_layers == 0:
        return Policy.distribute_layers(num_encoder_layers, num_stages), num_stages

    # the number of stages distributed between encoder and decoder is optimized in this way:
    # num_encoder_stages = argmin(abs(num_encoder_layers / encoder_stages - num_decoder_layers / decoder_stages))
    # s.t. num_encoder_stages + num_decoder_stages = num_stages, num_encoder_stages >= 1, num_decoder_stages >= 1
    def objective(num_encoder_stages):
        return abs(num_encoder_layers / num_encoder_stages - num_decoder_layers / (num_stages - num_encoder_stages))

    # builtin min() returns the first minimiser, matching np.argmin's tie-breaking
    num_encoder_stages = min(range(1, num_stages), key=objective)
    num_decoder_stages = num_stages - num_encoder_stages

    encoder_distribution = Policy.distribute_layers(num_encoder_layers, num_encoder_stages)
    decoder_distribution = Policy.distribute_layers(num_decoder_layers, num_decoder_stages)
    return encoder_distribution + decoder_distribution, num_encoder_stages
@staticmethod
def get_whisper_stage_index(layers_per_stage: List[int], stage: int,
                            decoder_starting_stage: int) -> Tuple[bool, int, int]:
    """Map a pipeline stage to the start/end indices of the layers it owns.

    Encoder stages are resolved against the encoder portion of
    ``layers_per_stage``; decoder stages against the remainder, with the stage
    number re-based so the first decoder stage counts as stage 0.
    """
    if stage >= decoder_starting_stage:
        # decoder side: shift the stage number onto the decoder-only distribution
        return Policy.get_stage_index(layers_per_stage[decoder_starting_stage:], stage - decoder_starting_stage)
    # encoder side
    return Policy.get_stage_index(layers_per_stage[:decoder_starting_stage], stage)
def get_held_layers(self) -> List[nn.Module]:
    """Return the modules held by the current pipeline stage.

    Encoder stages hold the conv stem and positional embeddings (first stage),
    a slice of encoder layers, and the encoder layer norm (last encoder stage).
    Decoder stages hold the token/positional embeddings (first decoder stage),
    a slice of decoder layers, and the decoder layer norm (last stage).
    """
    assert self.pipeline_stage_manager is not None, "pipeline_stage_manager is None"
    stage_manager = self.pipeline_stage_manager

    # Resolve the underlying WhisperModel, if this wrapper has one.
    if self.model.__class__.__name__ == 'WhisperModel':
        model = self.model
    elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
        model = self.model.model
    else:
        model = None

    if model:
        encoder = self.model.get_encoder()
        decoder = self.model.get_decoder()
    else:
        # whisper for audio classification holds encoder only
        encoder = self.model.encoder
        decoder = None

    num_encoder_layers = len(encoder.layers)
    if decoder:
        num_decoder_layers = len(decoder.layers)
    else:
        num_decoder_layers = 0

    held_layers = []
    # Split layers among stages and find where the decoder's stages begin.
    layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
        num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
    start_idx, end_idx = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
                                                               decoder_starting_stage)

    if stage_manager.stage < decoder_starting_stage:
        # current stage is in whisper's encoder
        if stage_manager.is_first_stage():
            held_layers.append(encoder.embed_positions)
            held_layers.append(encoder.conv1)
            held_layers.append(encoder.conv2)
        # last encoder stage also owns the encoder's final layer norm
        if stage_manager.stage == decoder_starting_stage - 1:
            held_layers.append(encoder.layer_norm)
        held_layers.extend(encoder.layers[start_idx:end_idx])
    else:
        # current stage is in whisper's decoder
        # TODO:(Jianghai) We divide encoder and decoder layers into different parts here,
        # the case encoder and decoder put in same stage should be add in the future.
        if stage_manager.stage == decoder_starting_stage:
            held_layers.append(decoder.embed_tokens)
            held_layers.append(decoder.embed_positions)
        if stage_manager.is_last_stage():
            held_layers.append(decoder.layer_norm)
        held_layers.extend(decoder.layers[start_idx:end_idx])
    return held_layers
def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
    """If under pipeline parallel setting, replacing the original forward method of huggingface
    to customized forward method, and add this changing to policy.

    Raises:
        ValueError: if no pipeline stage manager is configured.
    """
    if not self.pipeline_stage_manager:
        raise ValueError("set_pipeline_forward method can only be called when pipeline parallel is enabled.")
    stage_manager = self.pipeline_stage_manager

    # Resolve the underlying WhisperModel (if any) to locate encoder/decoder.
    if self.model.__class__.__name__ == 'WhisperModel':
        model = self.model
    elif self.model.__class__.__name__ == 'WhisperForConditionalGeneration':
        model = self.model.model
    else:
        model = None

    if model:
        encoder = self.model.get_encoder()
        decoder = self.model.get_decoder()
    else:
        # encoder-only variant (audio classification)
        encoder = self.model.encoder
        decoder = None

    num_encoder_layers = len(encoder.layers)
    if decoder:
        num_decoder_layers = len(decoder.layers)
    else:
        num_decoder_layers = 0

    # Compute this stage's layer slice and bind it into the replacement forward.
    layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
        num_encoder_layers, num_decoder_layers, stage_manager.num_stages)
    stage_index = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage_manager.stage,
                                                        decoder_starting_stage)

    method_replacement = {
        'forward':
            partial(new_forward,
                    stage_manager=stage_manager,
                    stage_index=stage_index,
                    decoder_starting_stage=decoder_starting_stage)
    }
    self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
# WhisperModel
class WhisperModelPolicy(WhisperPolicy):
......@@ -230,6 +388,24 @@ class WhisperModelPolicy(WhisperPolicy):
def __init__(self) -> None:
    # Nothing beyond the base whisper policy initialisation.
    super().__init__()
def module_policy(self):
    """Extend the base whisper policy with a pipeline forward for WhisperModel."""
    from transformers import WhisperModel

    base_policy = super().module_policy()
    if self.pipeline_stage_manager is not None:
        # Swap in the stage-aware forward when pipeline parallelism is active.
        self.set_pipeline_forward(model_cls=WhisperModel,
                                  new_forward=WhisperPipelineForwards.whisper_model_forward,
                                  policy=base_policy)
    return base_policy
def get_held_layers(self) -> List[nn.Module]:
    # WhisperModel adds no extra modules beyond the base encoder/decoder split.
    return super().get_held_layers()
def get_shared_params(self) -> List[Dict[int, Tensor]]:
    """No shared params in whisper model."""
    return []
# WhisperForConditionalGeneration
class WhisperForConditionalGenerationPolicy(WhisperPolicy):
......@@ -238,20 +414,82 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
super().__init__()
def module_policy(self):
    """Build the shard policy for WhisperForConditionalGeneration.

    Extends the base whisper policy with a parallelized lm head and, when a
    pipeline stage manager is present, a stage-aware pipeline forward.
    """
    # NOTE: the span previously contained stale pre-merge lines that returned
    # early and made the pipeline-forward setup unreachable dead code.
    from transformers import WhisperForConditionalGeneration

    policy = super().module_policy()
    policy = self.add_lm_head_policy(policy)
    if self.pipeline_stage_manager is not None:
        self.set_pipeline_forward(model_cls=WhisperForConditionalGeneration,
                                  new_forward=WhisperPipelineForwards.whisper_for_conditional_generation_forward,
                                  policy=policy)
    return policy
def postprocess(self):
    # Re-tie the output projection to the decoder token embedding after sharding,
    # so proj_out shares the (possibly sharded) embedding weight object.
    binding_map = {"model.decoder.embed_tokens.weight": "proj_out.weight"}
    for k, v in binding_map.items():
        param = getattr_(self.model, k)
        setattr_(self.model, v, param)
    return self.model
def get_held_layers(self) -> List[nn.Module]:
    # In addition to the layers held by the base policy, the last pipeline
    # stage also owns the output projection head.
    held_layers = super().get_held_layers()
    if self.pipeline_stage_manager.is_last_stage():
        held_layers.append(self.model.proj_out)
    return held_layers
def get_shared_params(self) -> List[Dict[int, Tensor]]:
    """Return parameters shared across pipeline stages.

    When ``proj_out`` is tied to the decoder token embedding (same object),
    the tied weight lives on two stages — the first decoder stage and the
    last stage — and is reported here so they can be kept in sync.
    """
    module = self.model
    model = module.model
    if model:
        encoder = self.model.get_encoder()
        decoder = self.model.get_decoder()
    else:
        # NOTE(review): ``module.model`` appears to always be set for
        # WhisperForConditionalGeneration, so this branch looks unreachable —
        # confirm before relying on it.
        encoder = self.model.encoder
        decoder = None

    num_encoder_layers = len(encoder.layers)
    if decoder:
        num_decoder_layers = len(decoder.layers)
    else:
        num_decoder_layers = 0

    stage_manager = self.pipeline_stage_manager
    # Sharing only matters when there is more than one pipeline stage.
    if stage_manager is not None and stage_manager.num_stages > 1:
        _, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(num_encoder_layers, num_decoder_layers,
                                                                            stage_manager.num_stages)
        shared_params = []
        shared_embedding = {}
        # identity check: tied weights mean the two modules are the same object
        if id(module.proj_out) == id(model.decoder.embed_tokens):
            shared_embedding[decoder_starting_stage] = model.decoder.embed_tokens
            shared_embedding[stage_manager.num_stages - 1] = module.proj_out
        if len(shared_embedding) > 0:
            shared_params.append(shared_embedding)
        return shared_params
    return []
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
    """Shard policy for whisper audio classification.

    The model is encoder-only; the last pipeline stage additionally holds the
    projector and classifier head.
    """

    def __init__(self) -> None:
        super().__init__()

    def preprocess(self):
        # No vocabulary embedding to resize in the classification variant.
        return self.model

    def module_policy(self):
        """Extend the base policy with a pipeline forward when stages are configured."""
        from transformers import WhisperForAudioClassification

        base_policy = super().module_policy()
        if self.pipeline_stage_manager is not None:
            self.set_pipeline_forward(model_cls=WhisperForAudioClassification,
                                      new_forward=WhisperPipelineForwards.whisper_for_audio_classification_forward,
                                      policy=base_policy)
        return base_policy

    def get_held_layers(self) -> List[nn.Module]:
        layers = super().get_held_layers()
        if self.pipeline_stage_manager.is_last_stage():
            # classification head lives on the final stage
            layers.extend([self.model.projector, self.model.classifier])
        return layers

    def get_shared_params(self) -> List[Dict[int, Tensor]]:
        """No weight tying in the audio classification variant."""
        return []
......@@ -20,6 +20,8 @@ class ShardConfig:
enable_tensor_parallelism (bool): Whether to turn on tensor parallelism, default is True.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False.
enable_all_optimization (bool): Whether to turn on all optimization, default is False.
enable_sequence_parallelism (bool): Whether to turn on sequence parallelism, default is False.
enable_sequence_overlap (bool): Whether to turn on sequence overlap, default is False.
"""
tensor_parallel_process_group: Optional[ProcessGroup] = None
pipeline_stage_manager: Optional[PipelineStageManager] = None
......@@ -28,6 +30,8 @@ class ShardConfig:
enable_all_optimization: bool = False
enable_flash_attention: bool = False
enable_jit_fused: bool = False
enable_sequence_parallelism: bool = False
enable_sequence_overlap: bool = False
# pipeline_parallel_size: int
# data_parallel_size: int
......@@ -40,6 +44,11 @@ class ShardConfig:
return self._tensor_parallel_size
def __post_init__(self):
if not self.enable_tensor_parallelism and self.enable_sequence_parallelism:
raise ValueError(
"enable_sequence_parallelism can only be set to True when enable_tensor_parallelism is True")
if not self.enable_sequence_parallelism and self.enable_sequence_overlap:
raise ValueError("enable_sequence_overlap can only be set to True when enable_sequence_parallelism is True")
if not self.enable_tensor_parallelism:
self._tensor_parallel_size = 1
else:
......@@ -57,3 +66,5 @@ class ShardConfig:
self.enable_fused_normalization = True
self.enable_flash_attention = True
self.enable_jit_fused = True
self.enable_sequence_parallelism = True
self.enable_sequence_overlap = True
......@@ -92,22 +92,21 @@ class ModelSharder(object):
param_replacement (List[Callable]): The function list to get parameter shard information in policy
method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
# released layers are not shardable
can_replace_param_or_layer = include is None or module in include
if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
(module.__class__ == origin_cls):
if attr_replacement is not None:
self._replace_attr(module, attr_replacement)
if param_replacement is not None and can_replace_param_or_layer:
if param_replacement is not None and (include is None or module in include):
self._replace_param(module, param_replacement)
if method_replacement is not None:
self._replace_method(module, method_replacement)
if sub_module_replacement is not None and can_replace_param_or_layer:
self._replace_sub_module(module, sub_module_replacement)
if sub_module_replacement is not None:
self._replace_sub_module(module, sub_module_replacement, include)
for name, child in module.named_children():
self._recursive_replace_layer(child,
......@@ -154,18 +153,17 @@ class ModelSharder(object):
bound_method = MethodType(new_method, module)
setattr(module, method_name, bound_method)
def _replace_sub_module(
self,
def _replace_sub_module(self,
org_layer: nn.Module,
sub_module_replacement: List[SubModuleReplacementDescription],
) -> None:
include: Optional[Set[nn.Module]] = None) -> None:
r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
Args:
org_layer (torch.nn.Module): The origin layer object to shard
sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
"""
for description in sub_module_replacement:
suffix = description.suffix
......@@ -174,9 +172,12 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None'
# TODO: support different parallel mode
native_sub_module = getattr_(org_layer, suffix, ignore=True)
# Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
continue
assert not isinstance(native_sub_module, target_module), \
f"The module with suffix {suffix} has been replaced, please check the policy"
......
......@@ -10,8 +10,9 @@ import torch.nn as nn
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_default_group
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.nn.parallel.data_parallel import _cast_float, free_storage
......@@ -733,7 +734,7 @@ class GeminiDDP(ModelWrapper):
Yields:
Iterator[OrderedDict]: A generator of state dict shard
"""
sharder = _StateDictSharder(max_shard_size)
sharder = StateDictSharder(max_shard_size)
# get the mapping between copies and fp16 parameters
fp16_to_fp32 = dict()
......@@ -755,7 +756,7 @@ class GeminiDDP(ModelWrapper):
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0, dtype))
gathered_param = gathered_param_buffer.pop(fp32_param)
block, block_size = sharder.append(prefix + name, gathered_param)
block, block_size = sharder.append_param(prefix + name, gathered_param)
if block is not None:
yield block, block_size
......@@ -766,7 +767,7 @@ class GeminiDDP(ModelWrapper):
for name, buf in self.named_buffers():
if buf is not None and name not in self._non_persistent_buffers_set:
buffer = buf if keep_vars else buf.detach()
block, block_size = sharder.append(prefix + name, buffer)
block, block_size = sharder.append_param(prefix + name, buffer)
if block is not None:
yield block, block_size
# save extra states
......@@ -774,32 +775,10 @@ class GeminiDDP(ModelWrapper):
if getattr(self.__class__, "get_extra_state",
torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
extra_state = self.get_extra_state()
block, block_size = sharder.append(extra_state_key, extra_state)
block, block_size = sharder.append_param(extra_state_key, extra_state)
if block is not None:
yield block, block_size
yield sharder.current_block, sharder.current_block_size
class _StateDictSharder:
    """Accumulate named tensors into shards of at most ``max_shard_size`` bytes."""

    def __init__(self, max_shard_size: int) -> None:
        self.max_shard_size = max_shard_size
        # shard currently being filled, and its accumulated byte size
        self.current_block = OrderedDict()
        self.current_block_size = 0

    def append(self, name: str, tensor: torch.Tensor) -> Tuple[Optional[OrderedDict], int]:
        """Add ``tensor`` under ``name``; return the previous shard if it just filled up.

        Returns ``(None, 0)`` while the current shard still has room; otherwise
        the completed shard and its size, after starting a fresh shard that
        receives the new tensor.
        """
        tensor_size = calculate_tensor_size(tensor)
        completed, completed_size = None, 0

        # Close the current shard only if it is non-empty and the new tensor
        # would push it past the limit — an oversized tensor still gets a
        # shard of its own rather than being dropped.
        would_overflow = self.current_block_size + tensor_size > self.max_shard_size
        if would_overflow and self.current_block_size > 0:
            completed, completed_size = self.current_block, self.current_block_size
            self.current_block = OrderedDict()
            self.current_block_size = 0

        self.current_block[name] = tensor
        self.current_block_size += tensor_size
        return completed, completed_size
......@@ -10,7 +10,7 @@ from torch.nn import Parameter
from torch.optim import Optimizer
from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
from colossalai.checkpoint_io.utils import calculate_tensor_size
from colossalai.checkpoint_io.utils import calculate_tensor_size, StateDictSharder
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import CPUAdam, FusedAdam, HybridAdam
......@@ -692,49 +692,17 @@ class GeminiOptimizer(OptimizerWrapper):
Iterator[OrderedDict]: A generator of state dict shard of optimizer states.
"""
current_block = {}
current_block_size = 0
sharder = StateDictSharder(max_shard_size)
for param_id in self.id_to_real_params.keys():
dist.barrier()
state = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
ret_block = None
ret_block_size = 0
# A state might contain more than one tensors.
# e.g. each Adam state includes: 'step', 'exp_avg', 'exp_avg_sq'
state_size = 0
isDTensor = False
for state_tensor in state.values():
# When state_tensor is not of Tensor class,
# e.g., a SGD optimizer with momentum set to 0 can have None as state
# The calculation of tensor size should be skipped to avoid error.
if not isinstance(state_tensor, torch.Tensor):
continue
# If the states are stored as DTensors, mark isDTensor as true.
if is_distributed_tensor(state_tensor):
isDTensor = True
state_size += calculate_tensor_size(state_tensor)
if not isDTensor:
if current_block_size + state_size > max_shard_size and current_block_size > 0:
ret_block = current_block
ret_block_size = current_block_size
current_block = {}
current_block_size = 0
current_block[param_id] = state
current_block_size += state_size
if ret_block != None:
yield ret_block, ret_block_size
block, block_size = sharder.append_optim_state(param_id, state)
if block is not None:
yield block, block_size
yield current_block, current_block_size
yield sharder.current_block, sharder.current_block_size
def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
raise NotImplementedError('Gemini does not support clip_grad_by_value')
......
......@@ -338,6 +338,24 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
self.zero_grad()
def backward_by_grad(self, tensor, grad):
    """Run a backward pass seeded with an externally supplied gradient.

    When ``require_grad_sync`` is off (gradient accumulation), the backward
    runs but gradient reduction is skipped and local grads are kept.
    """
    # ZeRO-2 partitions gradients across ranks, which conflicts with keeping
    # unsynchronized local gradients for accumulation.
    assert not(self._partition_grads and not self.require_grad_sync), \
        "ZeRO2(partition_grads) and gradient accumulation(no_sync) are not compatible"

    if self.mixed_precision_mixin is not None:
        # let the mixin transform the incoming gradient (e.g. loss scaling) first
        grad = self.mixed_precision_mixin.pre_backward_by_grad(tensor, grad)
    torch.autograd.backward(tensor, grad)

    if not self.require_grad_sync:
        # accumulation step: keep local grads, skip reduction and zeroing
        return
    self._reduce_grad(self._partition_grads)

    # clear reduced grads
    if self._overlap_communication:
        # wait for async reduction streams before zeroing the buffers
        torch.cuda.synchronize()

    self.zero_grad()
def zero_grad(self, set_to_none=True):
"""
Set parameter gradients to zero. If set_to_none = True, gradient
......@@ -363,7 +381,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
def step(self, closure=None):
assert closure is None, 'closure is not supported by step()'
if not self.require_grad_sync:
return
......
......@@ -7,13 +7,15 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh
```
### Results on 2-GPU
### Bert-Finetune Results
| Plugin | Accuracy | F1-score | GPU number |
| -------------- | -------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% | 2 |
| torch_ddp_fp16 | 84.7% | 88.8% | 2 |
| gemini | 84.0% | 88.4% | 2 |
| hybrid_parallel | 84.5% | 88.6% | 4 |
| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |
## Benchmark
```
......
import argparse
from typing import List, Union
from contextlib import nullcontext
from typing import Callable, List, Union
import evaluate
import torch
import torch.distributed as dist
import torch.nn as nn
from data import GLUEDataBuilder
from torch.optim import Optimizer
from torch.optim import Adam, Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
......@@ -18,8 +20,9 @@ from transformers import (
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
......@@ -32,20 +35,76 @@ LEARNING_RATE = 2.4e-5
WEIGHT_DECAY = 0.01
WARMUP_FRACTION = 0.1
output_transform_fn = lambda x: x
criterion = lambda x: x.loss
def move_to_cuda(batch):
    """Return a copy of ``batch`` with every value moved onto the CUDA device."""
    return {key: value.cuda() for key, value in batch.items()}
@torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
def evaluate_model(
model: nn.Module,
optimizer,
criterion,
test_dataloader: Union[DataLoader, List[DataLoader]],
num_labels: int,
task_name: str,
eval_splits: List[str],
booster: Booster,
coordinator: DistCoordinator,
):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval()
def evaluate_subset(dataloader: DataLoader):
accum_loss = torch.zeros(1, device=get_current_device())
for batch in dataloader:
batch = move_to_cuda(batch)
labels = batch["labels"]
batch_size = batch["input_ids"].shape[0]
if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
pg_mesh = booster.plugin.pg_mesh
pp_group = booster.plugin.pp_group
current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
current_rank = dist.get_rank()
#TODO pass dataloader to execute_pipeline directly
batch = iter([batch])
outputs = booster.execute_pipeline(batch,
model,
criterion,
optimizer,
return_loss=True,
return_outputs=True)
if booster.plugin.stage_manager.is_last_stage():
val_loss = outputs["loss"]
logits = outputs["outputs"]["logits"]
accum_loss.add_(val_loss)
if num_labels > 1:
preds = torch.argmax(logits, axis=1)
elif num_labels == 1:
preds = logits.squeeze()
dist.broadcast(preds, src=current_rank, group=pp_group)
dist.broadcast(val_loss, src=current_rank, group=pp_group)
metric.add_batch(predictions=preds, references=labels)
elif current_rank in current_pp_group_ranks:
val_loss = torch.empty((1,), device=get_current_device())
preds = torch.empty((batch_size,), dtype=torch.int64, device=get_current_device())
dist.broadcast(preds, src=current_pp_group_ranks[-1], group=pp_group)
dist.broadcast(val_loss, src=current_pp_group_ranks[-1], group=pp_group)
accum_loss.add_(val_loss)
metric.add_batch(predictions=preds, references=labels)
else:
batch = move_to_cuda(batch)
outputs = model(**batch)
val_loss, logits = outputs[:2]
......@@ -56,14 +115,13 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
elif num_labels == 1:
preds = logits.squeeze()
labels = batch["labels"]
metric.add_batch(predictions=preds, references=labels)
results = metric.compute()
dist.all_reduce(accum_loss.div_(len(dataloader)))
if coordinator.is_master():
if coordinator.is_master() and results is not None:
results['loss'] = accum_loss.item() / coordinator.world_size
return results
if isinstance(test_dataloader, DataLoader):
......@@ -77,25 +135,43 @@ def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[Dat
return final_results
def train_epoch(epoch: int, model: nn.Module, optimizer: Optimizer, _criterion: Callable, lr_scheduler: LRScheduler,
                train_dataloader: DataLoader, booster: Booster, coordinator: DistCoordinator):
    """Run one training epoch.

    With a pipeline-parallel plugin the forward/backward is delegated to
    ``booster.execute_pipeline`` and the loss is only visible on the last
    stage; otherwise a plain forward plus ``booster.backward`` is used.

    NOTE: the span previously contained diff-merge residue (a duplicated old
    signature and stale forward/logging lines) which is removed here.
    """
    model.train()
    # Only the last pipeline stage sees the loss, so it must drive the progress bar.
    is_pp_last_stage = hasattr(
        booster.plugin,
        "stage_manager") and booster.plugin.stage_manager is not None and booster.plugin.stage_manager.is_last_stage()

    with tqdm(train_dataloader,
              desc=f'Epoch [{epoch + 1}/{NUM_EPOCHS}]',
              disable=not (coordinator.is_master() or is_pp_last_stage)) as pbar:
        for batch in pbar:
            # Forward pass
            batch = move_to_cuda(batch)
            if hasattr(booster.plugin, "stage_manager") and booster.plugin.stage_manager is not None:
                # TODO pass train_dataloader to execute_pipeline directly
                batch = iter([batch])
                outputs = booster.execute_pipeline(batch,
                                                   model,
                                                   _criterion,
                                                   optimizer,
                                                   return_loss=True,
                                                   return_outputs=True)
                # Backward and optimize
                if booster.plugin.stage_manager.is_last_stage():
                    loss = outputs['loss']
                    pbar.set_postfix({'loss': loss.item()})
            else:
                outputs = model(**batch)
                loss = _criterion(outputs, None)

                # Backward
                booster.backward(loss, optimizer)
                pbar.set_postfix({'loss': loss.item()})

            optimizer.step()
            optimizer.zero_grad()
            lr_scheduler.step()
def main():
# ==============================
......@@ -107,7 +183,7 @@ def main():
'--plugin',
type=str,
default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero'],
choices=['torch_ddp', 'torch_ddp_fp16', 'gemini', 'low_level_zero', 'hybrid_parallel'],
help="plugin to use")
parser.add_argument(
"--model_type",
......@@ -116,6 +192,7 @@ def main():
help="bert or albert",
)
parser.add_argument('--target_f1', type=float, default=None, help="target f1 score. Raise exception if not reached")
parser.add_argument('--use_lazy_init', type=bool, default=False, help="for initiating lazy init context")
args = parser.parse_args()
if args.model_type == 'bert':
......@@ -145,6 +222,17 @@ def main():
plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5)
elif args.plugin == 'hybrid_parallel':
# modify the param accordingly for finetuning test cases
plugin = HybridParallelPlugin(tp_size=1,
pp_size=2,
num_microbatches=None,
microbatch_size=1,
enable_all_optimization=True,
zero_stage=1,
precision='fp16',
initial_scale=1)
booster = Booster(plugin=plugin, **booster_kwargs)
......@@ -165,8 +253,9 @@ def main():
# bert pretrained model
cfg = AutoConfig.from_pretrained(model_name, num_labels=data_builder.num_labels)
if model_name == "bert-base-uncased":
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg)
model = BertForSequenceClassification.from_pretrained(model_name, config=cfg).cuda()
elif model_name == "albert-xxlarge-v2":
model = AlbertForSequenceClassification.from_pretrained(model_name, config=cfg)
else:
......@@ -196,19 +285,27 @@ def main():
num_training_steps=total_steps,
)
def _criterion(outputs, inputs):
outputs = output_transform_fn(outputs)
loss = criterion(outputs)
return loss
# ==============================
# Boost with ColossalAI
# ==============================
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)
model, optimizer, _criterion, _, lr_scheduler = booster.boost(model,
optimizer,
criterion=_criterion,
lr_scheduler=lr_scheduler)
# ==============================
# Train model
# ==============================
for epoch in range(NUM_EPOCHS):
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
train_epoch(epoch, model, optimizer, _criterion, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator)
results = evaluate_model(model, optimizer, _criterion, test_dataloader, data_builder.num_labels, args.task,
data_builder.eval_splits, booster, coordinator)
if coordinator.is_master():
print(results)
......
......@@ -3,6 +3,6 @@ set -xe
pip install -r requirements.txt
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero"; do
for plugin in "torch_ddp" "torch_ddp_fp16" "gemini" "low_level_zero" "hybrid_parallel"; do
torchrun --standalone --nproc_per_node 4 finetune.py --target_f1 0.86 --plugin $plugin --model_type "bert"
done
[pytest]
markers =
cpu: tests which can run on CPU
gpu: tests which requires a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features
dist: tests which are run in a multi-GPU or multi-machine environment (at least 4 GPUs)
largedist: tests which are run in a multi-GPU or multi-machine environment (at least 8 GPUs)
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx --ignore=tests/test_legacy
......@@ -2,7 +2,7 @@ from .albert import *
from .bert import *
from .blip2 import *
from .bloom import *
from .chatglm import *
from .chatglm2 import *
from .gpt import *
from .llama import *
from .opt import *
......
......@@ -12,8 +12,8 @@ from ..registry import ModelAttribute, model_zoo
def data_gen():
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1]])
input_ids = torch.tensor([[5941, 15, 2670, 3543, 632, 2075, 632, 2075]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]])
return dict(input_ids=input_ids, attention_mask=attention_mask)
......
import pytest
import torch
import torch.distributed as dist
from torch.optim import Adam
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_close_loose,
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
# TODO (Baizhou): Add test cases for shard=False
@clear_cache_before_run()
@parameterize('shard', [True])
@parameterize('model_name', ['transformers_gpt'])
@parameterize('size_per_shard', [32])
@parameterize('test_config', [{
    'tp_size': 4,
    'pp_size': 1,
    'precision': 'fp32',
}, {
    'tp_size': 2,
    'pp_size': 2,
    'num_microbatches': 4,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 2,
    'pp_size': 1,
    'zero_stage': 2,
    'precision': 'fp16',
    'initial_scale': 1
}, {
    'tp_size': 1,
    'pp_size': 2,
    'num_microbatches': 4,
    'zero_stage': 1,
    'precision': 'fp16',
    'initial_scale': 1
}])
def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_config: dict):
    """Round-trip checkpoint test for HybridParallelPlugin.

    Trains one step, saves model + optimizer checkpoints, loads them into a
    fresh model/optimizer pair, asserts the state dicts match, then runs one
    more step on both and asserts selected weights stay close.

    Args:
        shard: whether to save sharded checkpoints.
        model_name: key into the model zoo sub-registry.
        size_per_shard: max shard size passed to the booster save APIs.
        test_config: kwargs forwarded to HybridParallelPlugin (tp/pp/zero setup).
    """
    (model_fn, data_gen_fn, output_transform_fn, loss_fn,
     _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = loss_fn
    plugin = HybridParallelPlugin(**test_config)
    booster = Booster(plugin=plugin)

    def _criterion(outputs, inputs):
        # (outputs, inputs) -> scalar loss adapter used by execute_pipeline.
        outputs = output_transform_fn(outputs)
        loss = criterion(outputs)
        return loss

    def _preprocess_data(data):
        # With pipeline parallelism (stage_manager set) the data must be an
        # iterator of batches; the batch dim is repeated to 4 — presumably to
        # match num_microbatches=4 in the pp configs (TODO confirm).
        if booster.plugin.stage_manager is not None:
            for k, v in data.items():
                if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__:
                    new_shape = [1] * v.dim()
                    new_shape[0] = 4
                    data[k] = v.to('cuda').repeat(*new_shape)
            return iter([data])
        else:
            return {k: v.cuda() for k, v in data.items()}

    model = model_fn().cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    data = data_gen_fn()
    model.train()
    # One forward/backward step so the optimizer has non-trivial state to save.
    if booster.plugin.stage_manager is not None:
        booster.execute_pipeline(_preprocess_data(data),
                                 model,
                                 _criterion,
                                 optimizer,
                                 return_loss=True,
                                 return_outputs=False)
    else:
        output = model(**_preprocess_data(data))
        loss = criterion(output)
        optimizer.backward(loss)

    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        optimizer_ckpt_path = f"{tempdir}/optimizer"

        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        # Barrier so no rank starts loading before every rank finished saving.
        dist.barrier()

        new_model = model_fn().cuda()
        new_optimizer = Adam(new_model.parameters(), lr=1e-3)
        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

        booster.load_model(new_model, model_ckpt_path)
        check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
        booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
        check_state_dict_equal(optimizer.unwrap().state_dict(), new_optimizer.unwrap().state_dict(), False)
        dist.barrier()

        # Check whether the loaded model & optimizer works smoothly.
        model.train()
        new_model.train()
        if booster.plugin.stage_manager is not None:
            booster.execute_pipeline(_preprocess_data(data),
                                     model,
                                     _criterion,
                                     optimizer,
                                     return_loss=True,
                                     return_outputs=False)
            booster.execute_pipeline(_preprocess_data(data),
                                     new_model,
                                     _criterion,
                                     new_optimizer,
                                     return_loss=True,
                                     return_outputs=False)
        else:
            old_model_loss = criterion(model(**_preprocess_data(data)))
            optimizer.backward(old_model_loss)
            new_model_loss = criterion(new_model(**_preprocess_data(data)))
            new_optimizer.backward(new_model_loss)

        optimizer.step()
        new_optimizer.step()

        # Check updated weights: embedding and first MLP layer live on the
        # first pipeline stage, so only check there (or always, without pp).
        stage_manager = booster.plugin.stage_manager
        if stage_manager is None or stage_manager.is_first_stage():
            assert_close_loose(model.unwrap().wte.weight.data, new_model.unwrap().wte.weight.data, atol=5e-3, rtol=5e-3)
            assert_close_loose(model.unwrap().h[0].mlp.c_fc.weight.data,
                               new_model.unwrap().h[0].mlp.c_fc.weight.data,
                               atol=5e-3,
                               rtol=5e-3)

        dist.barrier()

    # Reset global sharding state so later parameterizations start clean.
    Randomizer.reset_index()
    clear_layout_converter()
def run_dist(rank, world_size, port):
    """Per-process entry point: bring up the distributed env, then run the test body."""
    colossalai.launch(config={},
                      rank=rank,
                      world_size=world_size,
                      host='localhost',
                      port=port,
                      backend='nccl')
    exam_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [4])
@rerun_if_address_is_in_use()
def test_hybrid_ckpIO(world_size):
    """Spawn `world_size` processes that each run the hybrid-parallel checkpoint test."""
    spawn(run_dist, world_size)
import os
import pytest
import torch
import torch.distributed as dist
from utils import shared_tempdir
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import (
check_state_dict_equal,
clear_cache_before_run,
parameterize,
rerun_if_address_is_in_use,
spawn,
)
from tests.kit.model_zoo import model_zoo
@clear_cache_before_run()
@parameterize('model_name', ['transformers_gpt'])
@parameterize('plugin_type', ['ddp', 'zero', 'gemini'])
def exam_from_pretrained(plugin_type: str, model_name: str, shard=True, size_per_shard=32):
    """Check that a booster-saved checkpoint can be reloaded via HuggingFace
    `from_pretrained` and yields an identical state dict.

    Args:
        plugin_type: which booster plugin to exercise ('ddp', 'zero', 'gemini').
        model_name: key into the model zoo sub-registry.
        shard: whether to save a sharded checkpoint.
        size_per_shard: max shard size passed to `booster.save_model`.
    """
    (model_fn, data_gen_fn, output_transform_fn, loss_fn,
     _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
    criterion = loss_fn

    if plugin_type == 'ddp':
        plugin = TorchDDPPlugin()
    elif plugin_type == 'zero':
        plugin = LowLevelZeroPlugin(stage=2, max_norm=1.0, initial_scale=32)
    elif plugin_type == 'gemini':
        plugin = GeminiPlugin(precision="fp16", initial_scale=32)
    else:
        raise ValueError(f"Plugin with type {plugin_type} is invalid, please check your argument.")

    booster = Booster(plugin=plugin)

    model = model_fn().cuda()
    # Remember the concrete HF class so we can reload with from_pretrained below.
    model_huggingface_cls = model.__class__
    optimizer = HybridAdam(model.parameters(), lr=0.001)
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # One training step so the saved weights differ from fresh initialization.
    data = data_gen_fn()
    data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
    output = model(**data)
    loss = criterion(output)

    booster.backward(loss, optimizer)
    optimizer.step()

    with shared_tempdir() as tempdir:
        model_ckpt_path = f"{tempdir}/model"
        booster.save_model(model, model_ckpt_path, shard=shard, size_per_shard=size_per_shard)
        # Ensure saving completed on every rank before anyone loads.
        dist.barrier()

        new_model = model_huggingface_cls.from_pretrained(model_ckpt_path)
        new_model = new_model.cuda()
        new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
        new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)

        if plugin_type == 'gemini':
            # Gemini shards parameters; gather the full state dict on all ranks.
            check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
                                   new_model.unwrap().state_dict(only_rank_0=False), False)
        else:
            check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)
        dist.barrier()
def run_dist(rank, world_size, port):
    """Per-process entry point: initialize distributed env, then run the compatibility test."""
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    exam_from_pretrained()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [2])
@rerun_if_address_is_in_use()
def test_huggingface_compatibility(world_size):
    """Spawn `world_size` processes that each run the HuggingFace from_pretrained round-trip test."""
    spawn(run_dist, world_size)
......@@ -8,7 +8,6 @@ import pytest
from colossalai.context.config import Config
@pytest.mark.cpu
def test_load_config():
filename = Path(__file__).parent.joinpath('sample_config.py')
config = Config.from_file(filename)
......
......@@ -143,7 +143,6 @@ def run_dist(rank, world_size, port, backend, port_list, host):
reset_seeds()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_context():
"""
......
......@@ -5,11 +5,10 @@ import os
from pathlib import Path
import pytest
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
@pytest.mark.cpu
def test_cifar10_dataset():
# build transform
transform_pipeline = [transforms.ToTensor()]
......
......@@ -53,7 +53,6 @@ def run_data_sampler(rank, world_size, port):
torch.cuda.empty_cache()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
spawn(run_data_sampler, 4)
......
......@@ -64,7 +64,6 @@ def run_data_sampler(rank, world_size, port):
torch.cuda.empty_cache()
@pytest.mark.cpu
@rerun_if_address_is_in_use()
def test_data_sampler():
spawn(run_data_sampler, 4)
......
from colossalai.shardformer.policies.t5 import T5BasePolicy
def test_t5_pipeline_distribution():
    """Verify that T5BasePolicy.distribute_t5_layers reports the expected
    pipeline stage at which decoder layers begin, across encoder/decoder
    layer counts and stage counts."""
    encoder_layer_counts = [2, 1, 3, 2, 3, 2, 10, 5]
    decoder_layer_counts = [2, 8, 0, 2, 1, 5, 6, 22]
    stage_counts = [2, 2, 2, 4, 4, 4, 8, 8]
    expected_decoder_stages = [1, 1, 2, 2, 3, 1, 5, 2]

    cases = zip(encoder_layer_counts, decoder_layer_counts, stage_counts, expected_decoder_stages)
    for n_encoder, n_decoder, n_stages, expected_stage in cases:
        _, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(n_encoder, n_decoder, n_stages)
        assert decoder_starting_stage == expected_stage
def test_t5_pipeline_layers():
    """Verify that T5BasePolicy.get_t5_stage_index yields the expected
    [start, end) layer range for every pipeline stage, given the layer
    distribution computed by distribute_t5_layers."""
    # Each case: (num_encoder_layers, num_decoder_layers, num_stages, expected per-stage ranges).
    cases = [
        (2, 2, 2, [[0, 2], [0, 2]]),
        (3, 0, 2, [[0, 1], [1, 3]]),
        (2, 2, 4, [[0, 1], [1, 2], [0, 1], [1, 2]]),
        (4, 8, 4, [[0, 4], [0, 3], [3, 6], [6, 8]]),
    ]

    for n_encoder, n_decoder, n_stages, expected_ranges in cases:
        layers_per_stage, decoder_starting_stage = T5BasePolicy.distribute_t5_layers(
            n_encoder, n_decoder, n_stages)
        for stage, (expected_start, expected_end) in enumerate(expected_ranges):
            predicted_start, predicted_end = T5BasePolicy.get_t5_stage_index(layers_per_stage, stage,
                                                                             decoder_starting_stage)
            assert predicted_start == expected_start
            assert predicted_end == expected_end
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment