Commit f7774ec0 authored by FoolPlayer, committed by Frank Lee
Browse files

[Shardformer] Downstream bert (#3979)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage

* add downstream model of bert

* remove unused code
parent c1c672d0
...@@ -10,11 +10,31 @@ def build_policies(): ...@@ -10,11 +10,31 @@ def build_policies():
""" """
auto_policy_dict = {} auto_policy_dict = {}
from transformers import BertModel
from .bert import BertModelPolicy
auto_policy_dict[BertModel] = BertModelPolicy
from transformers import BertForPreTraining
from .bert import BertForPretrainingPolicy
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy
from transformers import BertLMHeadModel
from .bert import BertLMHeadModelPolicy
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy
from transformers import BertForMaskedLM from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers import BertForNextSentencePrediction
from .bert import BertForNextSentencePredictionPolicy
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy
from transformers import BertForSequenceClassification from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy from .bert import BertForSequenceClassificationPolicy
...@@ -34,6 +54,11 @@ def build_policies(): ...@@ -34,6 +54,11 @@ def build_policies():
from .llama import LlamaForCausalLMPolicy from .llama import LlamaForCausalLMPolicy
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
from transformers import BertForMultipleChoice
from .bert import BertForMultipleChoicePolicy
auto_policy_dict[BertForMultipleChoice] = BertForMultipleChoicePolicy
from transformers import GPT2Model from transformers import GPT2Model
from .gpt2 import GPT2Policy from .gpt2 import GPT2Policy
......
...@@ -35,12 +35,6 @@ class BertPolicy(Policy): ...@@ -35,12 +35,6 @@ class BertPolicy(Policy):
]), ]),
} }
@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
@staticmethod @staticmethod
def attn_in(): def attn_in():
return [ return [
...@@ -148,9 +142,53 @@ class BertPolicy(Policy): ...@@ -148,9 +142,53 @@ class BertPolicy(Policy):
replace_layer=col_nn.VocabParallelEmbedding1D, replace_layer=col_nn.VocabParallelEmbedding1D,
)] )]
@staticmethod
def unembedding():
    """Describe how the ``decoder`` sub-layer is replaced when sharding.

    Returns:
        list: a one-element list holding a ``Col_Layer`` that targets the
        ``decoder`` suffix, names its ``weight`` and ``bias`` parameters,
        replaces the layer with ``col_nn.Linear1D_Col`` and sets
        ``gather_output=True``.
    """
    decoder_layer = Col_Layer(
        suffix="decoder",
        weight="weight",
        bias="bias",
        replace_layer=col_nn.Linear1D_Col,
        gather_output=True,
    )
    return [decoder_layer]
# BertModel
class BertModelPolicy(BertPolicy):
    """Policy for ``BertModel``; reuses the base ``BertPolicy`` arguments as-is."""

    @staticmethod
    def argument_policy(config, world_size):
        """Return exactly what ``BertPolicy.argument_policy`` produces.

        Args:
            config: forwarded unchanged to ``BertPolicy.argument_policy``.
            world_size: forwarded unchanged to ``BertPolicy.argument_policy``.
        """
        base_argument = BertPolicy.argument_policy(config, world_size)
        return base_argument
from transformers import BertForMaskedLM # BertForPretraining
class BertForPretrainingPolicy(BertPolicy):
    """Policy for ``BertForPreTraining``: base BERT arguments plus an LM-head entry."""

    @staticmethod
    def argument_policy(config, world_size):
        """Merge the ``BertLMPredictionHead`` entry with the base arguments.

        Args:
            config: forwarded unchanged to ``BertPolicy.argument_policy``.
            world_size: forwarded unchanged to ``BertPolicy.argument_policy``.

        Returns:
            dict: the head entry first, followed by every base entry; a base
            entry replaces the head entry if both use the same key.
        """
        base_argument = BertPolicy.argument_policy(config, world_size)
        head_argument = Argument(attr_dict={}, param_funcs=[BertPolicy.unembedding])
        # Unpacking the base mapping last keeps the original precedence:
        # base entries override the head entry on a key clash.
        return {BertLMPredictionHead: head_argument, **base_argument}

    @staticmethod
    def inject_policy():
        """Return ``None``: this policy supplies no injection entry."""
        return None

    @staticmethod
    def binding_policy():
        """Return the parameter-name pair linking the word-embedding weight to the decoder weight."""
        bindings = {
            "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
        }
        return bindings
# BertForMaskedLM
from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_ from colossalai.shardformer.model.modeling_bert import BertForMaskedLM_
...@@ -161,7 +199,7 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -161,7 +199,7 @@ class BertForMaskedLMPolicy(BertPolicy):
base_argument = BertPolicy.argument_policy(config, world_size) base_argument = BertPolicy.argument_policy(config, world_size)
argument = { argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[ BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertForMaskedLMPolicy.unembedding, BertPolicy.unembedding,
]), ]),
} }
argument.update(base_argument) argument.update(base_argument)
...@@ -173,20 +211,56 @@ class BertForMaskedLMPolicy(BertPolicy): ...@@ -173,20 +211,56 @@ class BertForMaskedLMPolicy(BertPolicy):
return None return None
@staticmethod @staticmethod
def unembedding(): def binding_policy():
return [ return {
Col_Layer( "bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
suffix="decoder", }
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
)
]
class BertForSequenceClassificationPolicy(BertPolicy): # BertLMHeadModel
class BertLMHeadModelPolicy(BertPolicy):
@staticmethod
def argument_policy(config, world_size):
base_argument = BertPolicy.argument_policy(config, world_size)
argument = {
BertLMPredictionHead: Argument(attr_dict={}, param_funcs=[
BertPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod @staticmethod
def inject_policy(): def inject_policy():
return None return None
@staticmethod
def binding_policy():
return {
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
# BertForNextSentencePrediction
class BertForNextSentencePredictionPolicy(BertPolicy):
    """Policy for ``BertForNextSentencePrediction``; reuses the base ``BertPolicy`` arguments."""

    @staticmethod
    def argument_policy(config, world_size):
        """Return exactly what ``BertPolicy.argument_policy`` produces.

        Args:
            config: forwarded unchanged to ``BertPolicy.argument_policy``.
            world_size: forwarded unchanged to ``BertPolicy.argument_policy``.
        """
        base_argument = BertPolicy.argument_policy(config, world_size)
        return base_argument
# BertForSequenceClassification
class BertForSequenceClassificationPolicy(BertPolicy):
    """Policy for ``BertForSequenceClassification``; reuses the base ``BertPolicy`` arguments."""

    @staticmethod
    def argument_policy(config, world_size):
        """Return exactly what ``BertPolicy.argument_policy`` produces.

        Args:
            config: forwarded unchanged to ``BertPolicy.argument_policy``.
            world_size: forwarded unchanged to ``BertPolicy.argument_policy``.
        """
        base_argument = BertPolicy.argument_policy(config, world_size)
        return base_argument
# BertForMultipleChoice
class BertForMultipleChoicePolicy(BertPolicy):
    """Policy for ``BertForMultipleChoice``; reuses the base ``BertPolicy`` arguments."""

    @staticmethod
    def argument_policy(config, world_size):
        """Return exactly what ``BertPolicy.argument_policy`` produces.

        Args:
            config: forwarded unchanged to ``BertPolicy.argument_policy``.
            world_size: forwarded unchanged to ``BertPolicy.argument_policy``.
        """
        base_argument = BertPolicy.argument_policy(config, world_size)
        return base_argument
...@@ -13,6 +13,6 @@ class ShardConfig: ...@@ -13,6 +13,6 @@ class ShardConfig:
world_size (int): The world size of the distributed process world_size (int): The world size of the distributed process
gather_output (bool): Whether to gather the output of the model of the last layer gather_output (bool): Whether to gather the output of the model of the last layer
""" """
rank: int rank: int = None
world_size: int = 2 world_size: int = None
gather_output: bool = True gather_output: bool = True
...@@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli ...@@ -276,6 +276,7 @@ def shard_model(model: nn.Module, shard_config: ShardConfig = None, policy: Poli
shard_config (`ShardConfig`): the config for distribute information shard_config (`ShardConfig`): the config for distribute information
policy (`Policy`): the custom policy for sharding policy (`Policy`): the custom policy for sharding
""" """
# TODO: init shard_config automatically
sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy) sharder = ModelSharder(model=model, shard_config=shard_config, policy=policy)
sharder.shard() sharder.shard()
return model return model
import copy
import os import os
import random
import pytest import pytest
import torch import torch
from transformers import AutoTokenizer, BertConfig, BertForMaskedLM from transformers import (
AutoTokenizer,
BertConfig,
BertForMaskedLM,
BertForMultipleChoice,
BertForNextSentencePrediction,
BertForPreTraining,
BertForSequenceClassification,
BertLMHeadModel,
BertModel,
)
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
...@@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')), ...@@ -15,20 +25,21 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def build_model(rank, world_size): def build_model(rank, world_size, model):
config = BertConfig.from_pretrained('bert-base-uncased') config = BertConfig.from_pretrained('bert-base-uncased')
config.hidden_dropout_prob = 0 config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0 config.attention_probs_dropout_prob = 0
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config).to('cuda') org_model = model(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model = org_model.to('cuda')
shardconfig = ShardConfig( shardconfig = ShardConfig(
rank=rank, rank=rank,
world_size=world_size, world_size=world_size,
gather_output=True, gather_output=True,
) )
sharded_model = shard_model(BertForMaskedLM.from_pretrained('bert-base-uncased', config=config), sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
shardconfig).to('cuda')
return org_model, sharded_model return org_model, sharded_model
...@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model): ...@@ -85,12 +96,19 @@ def check_backward(org_model, sharded_model):
def check_bert(rank, world_size, port): def check_bert(rank, world_size, port):
disable_existing_loggers() disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
forward_list = [
org_model, sharded_model = build_model(rank, world_size) BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
check_forward(org_model, sharded_model) BertForSequenceClassification
check_backward(org_model, sharded_model) ]
backward_lsit = [BertForMaskedLM, BertLMHeadModel]
torch.cuda.empty_cache()
for model in forward_list:
org_model, sharded_model = build_model(rank, world_size, model)
check_forward(org_model, sharded_model)
if model in backward_lsit:
check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
@pytest.mark.dist @pytest.mark.dist
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment