Commit 74d176c8 authored by FoolPlayer's avatar FoolPlayer Committed by Frank Lee
Browse files

[shardformer] fix bert and gpt downstream with new api (#4024)

* fix bert downstream with new api

* remove comment line
parent e253a070
...@@ -76,6 +76,7 @@ class Policy(ABC): ...@@ -76,6 +76,7 @@ class Policy(ABC):
def __init__(self) -> None: def __init__(self) -> None:
self.model = None self.model = None
self.shard_config = None
def set_model(self, model: nn.Module) -> None: def set_model(self, model: nn.Module) -> None:
r""" r"""
...@@ -86,14 +87,23 @@ class Policy(ABC): ...@@ -86,14 +87,23 @@ class Policy(ABC):
""" """
self.model = model self.model = model
def set_shard_config(self, shard_config: ShardConfig) -> None:
r"""
Set shard config as an attribute of the Policy object.
Args:
shard_config (:class:`ShardConfig`): The shard config to be perform
"""
self.shard_config = shard_config
@abstractmethod @abstractmethod
def preprocess(self, shard_config: ShardConfig = None) -> nn.Module: def preprocess(self) -> nn.Module:
r""" r"""
Perform some preprocessing of the model, like reshaping the embedding layer Perform some preprocessing of the model, like reshaping the embedding layer
""" """
@abstractmethod @abstractmethod
def module_policy(self, shard_config: ShardConfig = None) -> Dict[Union[str, nn.Module], ModulePolicyDescription]: def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
r""" r"""
Return the dict for the modify policy, the key is the original layer class and the value is the Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer argument for the modify layer
......
...@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be ...@@ -4,41 +4,40 @@ from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer, Be
import colossalai.shardformer.layer.layers as col_nn import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.dropout import Dropout1D from colossalai.shardformer.layer.dropout import Dropout1D
from ..shard.shard_config import ShardConfig
from ..utils import getattr_, setattr_ from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class BertPolicy(Policy): class BertPolicy(Policy):
def preprocess(self, shard_config: ShardConfig = None): def preprocess(self):
# reshape the embedding layer # reshape the embedding layer
r""" r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size Reshape the Embedding layer to make the embedding dimension divisible by world_size
""" """
# TODO: # TODO:
vocab_size = self.model.config.vocab_size vocab_size = self.model.config.vocab_size
world_size = shard_config.tensor_parallel_size world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0: if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size) self.model.resize_token_embeddings(new_vocab_size)
return self.model return self.model
def module_policy(self, shard_config: ShardConfig = None): def module_policy(self):
return { return {
BertLayer: BertLayer:
ModulePolicyDescription( ModulePolicyDescription(
attribute_replacement={ attribute_replacement={
# 1. shard hidden size # 1. shard hidden size
"attention.self.all_head_size": "attention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size, self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"crossattention.self.all_head_size": "crossattention.self.all_head_size":
self.model.config.hidden_size // shard_config.tensor_parallel_size, self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
# 2. shard number of heads # 2. shard number of heads
"attention.self.num_attention_heads": "attention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size, self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
"crossattention.self.num_attention_heads": "crossattention.self.num_attention_heads":
self.model.config.num_attention_heads // shard_config.tensor_parallel_size, self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
}, },
param_replacement=[], param_replacement=[],
sub_module_replacement=[ sub_module_replacement=[
...@@ -100,13 +99,43 @@ class BertPolicy(Policy): ...@@ -100,13 +99,43 @@ class BertPolicy(Policy):
return self.model return self.model
# BertModel
class BertModelPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForPreTraining
class BertForPretrainingPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
def module_policy(self):
module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
}
module_policy.update(addon_module)
return module_policy
# BertForMaskedLM
class BertForMaskedLMPolicy(BertPolicy): class BertForMaskedLMPolicy(BertPolicy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
def module_policy(self, shard_config: ShardConfig = None): def module_policy(self):
module_policy = super().module_policy(shard_config) module_policy = super().module_policy()
addon_module = { addon_module = {
BertLMPredictionHead: BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={}, ModulePolicyDescription(attribute_replacement={},
...@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -124,16 +153,41 @@ class BertForMaskedLMPolicy(BertPolicy):
# BertLMHeadModel # BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy): class BertLMHeadModelPolicy(BertPolicy):
@staticmethod def __init__(self) -> None:
def argument_policy(config, world_size): super().__init__()
base_argument = BertPolicy.argument_policy(config, world_size)
argument = { def module_policy(self):
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ module_policy = super().module_policy()
BertPolicy.unembedding, addon_module = {
]), BertLMPredictionHead:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True})
])
} }
argument.update(base_argument) module_policy.update(addon_module)
return argument return module_policy
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
def __init__(self) -> None:
super().__init__()
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
def __init__(self) -> None: def __init__(self) -> None:
super().__init__() super().__init__()
...@@ -18,10 +18,10 @@ class ShardConfig: ...@@ -18,10 +18,10 @@ class ShardConfig:
will not calculate the loss and just return the output. will not calculate the loss and just return the output.
gather_output (bool): Whether to gather the output of the model of the last layer gather_output (bool): Whether to gather the output of the model of the last layer
""" """
data_parallel_size: int
tensor_parallel_size: int tensor_parallel_size: int
# TODO: add support for tensor parallel
pipeline_parallel_size: int # pipeline_parallel_size: int
# data_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d'] tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True inference_only: bool = True
gather_output: bool = True gather_output: bool = True
...@@ -40,6 +40,7 @@ class ModelSharder(object): ...@@ -40,6 +40,7 @@ class ModelSharder(object):
Shard the model according to the policy Shard the model according to the policy
""" """
self.policy.set_model(self.model) self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self.preprocess() self.preprocess()
self.replace_model_class() self.replace_model_class()
self.replace_module() self.replace_module()
...@@ -57,12 +58,12 @@ class ModelSharder(object): ...@@ -57,12 +58,12 @@ class ModelSharder(object):
self.model_config = self.model.config self.model_config = self.model.config
def preprocess(self) -> None: def preprocess(self) -> None:
self.model = self.policy.preprocess(self.shard_config) self.model = self.policy.preprocess()
def postprocess(self) -> None: def postprocess(self) -> None:
self.model = self.policy.postprocess() self.model = self.policy.postprocess()
def replace_model_class(self,) -> None: def replace_model_class(self) -> None:
r""" r"""
Replace the model to policy defined model Replace the model to policy defined model
Mainly modify the forward and backward to fit distributed model Mainly modify the forward and backward to fit distributed model
...@@ -83,14 +84,14 @@ class ModelSharder(object): ...@@ -83,14 +84,14 @@ class ModelSharder(object):
getattr(new_model_class, key), getattr(new_model_class, key),
) )
def replace_module(self,) -> None: def replace_module(self) -> None:
r""" r"""
Replace the module according to the policy, and replace the module one by one Replace the module according to the policy, and replace the module one by one
Args: Args:
model (:class:`torch.nn.Module`): The model to shard model (:class:`torch.nn.Module`): The model to shard
""" """
module_descriptions = self.policy.module_policy(self.shard_config) module_descriptions = self.policy.module_policy()
for module_description in module_descriptions.items(): for module_description in module_descriptions.items():
origin_layer_cls = module_description[0] origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement attr_replacement = module_description[1].attribute_replacement
......
...@@ -25,11 +25,7 @@ class ShardFormer: ...@@ -25,11 +25,7 @@ class ShardFormer:
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased') org_model = BertForMaskedLM.from_pretrained('bert-base-uncased')
shard_config = ShardConfig( shard_config = ShardConfig(
tensor_parallel_size=2, tensor_parallel_size=2,
data_parallel_size=1,
pipeline_parallel_size=1,
tensor_parallel_mode='1d', tensor_parallel_mode='1d',
inference_only=True,
gather_output=True
) )
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed() shard_former.init_distributed()
......
...@@ -7,7 +7,6 @@ from transformers import ( ...@@ -7,7 +7,6 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
BertConfig, BertConfig,
BertForMaskedLM, BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction, BertForNextSentencePrediction,
BertForPreTraining, BertForPreTraining,
BertForSequenceClassification, BertForSequenceClassification,
...@@ -36,12 +35,10 @@ def build_model(rank, world_size, model): ...@@ -36,12 +35,10 @@ def build_model(rank, world_size, model):
org_model.to('cuda') org_model.to('cuda')
# TODO: no need to transfer to cuda # TODO: no need to transfer to cuda
org_model_forshard.to('cuda') org_model_forshard.to('cuda')
shard_config = ShardConfig(tensor_parallel_size=2, shard_config = ShardConfig(
data_parallel_size=1, tensor_parallel_size=2,
pipeline_parallel_size=1, tensor_parallel_mode='1d',
tensor_parallel_mode='1d', )
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config) shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed() shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda') sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment