Commit c1d5453e authored by Frank Lee

[shardformer] adapted llama to the new API (#4036)

parent 74d176c8
import importlib
from dataclasses import dataclass

import torch.nn as nn

from .basepolicy import Policy


# removed in this commit: the old eager-import registry
def build_policies():
    r"""
    Build the policies for the model

    Return:
        The dict for the policies
    """
    auto_policy_dict = {}

    from transformers import BertModel
    from .bert import BertModelPolicy
    auto_policy_dict[BertModel] = BertModelPolicy

    from transformers import BertForPreTraining
    from .bert import BertForPretrainingPolicy
    auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy

    from transformers import BertLMHeadModel
    from .bert import BertLMHeadModelPolicy
    auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy

    from transformers import BertForMaskedLM
    from .bert import BertForMaskedLMPolicy
    auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

    from transformers import BertForNextSentencePrediction
    from .bert import BertForNextSentencePredictionPolicy
    auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy

    from transformers import BertForSequenceClassification
    from .bert import BertForSequenceClassificationPolicy
    auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

    from transformers.models.llama.modeling_llama import LlamaModel

    # from .llama import LlamaPolicy
    # auto_policy_dict[LlamaModel] = LlamaPolicy

    # from transformers import LlamaForSequenceClassification
    # from .llama import LlamaForSequenceClassificationPolicy
    # auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy

    # from transformers import LlamaForCausalLM
    # from .llama import LlamaForCausalLMPolicy
    # auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy

    # from transformers import GPT2Model
    # from .gpt2 import GPT2Policy
    # auto_policy_dict[GPT2Model] = GPT2Policy

    # from transformers import GPT2LMHeadModel
    # from .gpt2 import GPT2LMHeadModelPolicy
    # auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

    return auto_policy_dict


# added in this commit: policies are now located lazily via PolicyLocation
@dataclass
class PolicyLocation:
    """
    PolicyLocation describes the location of a policy class.

    Args:
        file_name (str): The file name of the policy under colossalai.shardformer.policies
        class_name (str): The class name of the policy class
    """
    file_name: str
    class_name: str


# we don't want to import all policies here
# as each policy file imports its own model zoo library
# we will allow the user to only import the policy file needed
_POLICY_LIST = {
    # BERT
    "transformers.models.bert.modeling_bert.BertModel":
        PolicyLocation(file_name="bert", class_name="BertPolicy"),
    "transformers.models.bert.modeling_bert.BertForPreTraining":
        PolicyLocation(file_name="bert", class_name="BertForPretrainingPolicy"),
    "transformers.models.bert.modeling_bert.BertForMaskedLM":
        PolicyLocation(file_name="bert", class_name="BertForMaskedLMPolicy"),
    "transformers.models.bert.modeling_bert.BertLMHeadModel":
        PolicyLocation(file_name="bert", class_name="BertLMHeadModelPolicy"),
    "transformers.models.bert.modeling_bert.BertForNextSentencePrediction":
        PolicyLocation(file_name="bert", class_name="BertForNextSentencePredictionPolicy"),
    "transformers.models.bert.modeling_bert.BertForSequenceClassification":
        PolicyLocation(file_name="bert", class_name="BertForSequenceClassificationPolicy"),
    "transformers.models.bert.modeling_bert.BertForMultipleChoice":
        PolicyLocation(file_name="bert", class_name="BertForMultipleChoicePolicy"),

    # LLaMA
    "transformers.models.llama.modeling_llama.LlamaModel":
        PolicyLocation(file_name="llama", class_name="LlamaPolicy"),
    "transformers.models.llama.modeling_llama.LlamaForCausalLM":
        PolicyLocation(file_name="llama", class_name="LlamaForCausalLMPolicy"),
    "transformers.models.llama.modeling_llama.LlamaForSequenceClassification":
        PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),

    # T5
    # GPT2
}


def import_policy(policy_location: PolicyLocation) -> Policy:
    """
    Dynamically import a Policy class based on the policy location.
    """
    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)
def _fullname(obj):
"""
Return the full name of an object, including the module name.
"""
klass = obj.__class__
module = klass.__module__
if module == 'builtins':
return klass.__qualname__ # avoid outputs like 'builtins.str'
return module + '.' + klass.__qualname__
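# illustrative note (not part of the commit): _fullname maps a model instance to the
# string key format used in _POLICY_LIST, e.g. a BertModel instance resolves to
# 'transformers.models.bert.modeling_bert.BertModel'.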
def get_autopolicy(model: nn.Module) -> Policy:
......@@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy:
Return:
:class:`Policy`: The auto policy for the model
"""
    # removed:
    auto_policy_dict = build_policies()
    policy = auto_policy_dict.get(model.__class__, None)
    if policy is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
        )

    # added:
    full_name = _fullname(model)
    policy_location = _POLICY_LIST.get(full_name, None)
    if policy_location is None:
        raise NotImplementedError(
            f"Auto policy for {model.__class__.__qualname__} is not implemented.\n Supported models are {list(_POLICY_LIST.keys())}"
        )
    else:
        policy = import_policy(policy_location)
    return policy()
# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
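For reference, here is a runnable version of the lookup sketched in the commented-out example above, now going through the lazy registry. It is illustrative only: it assumes transformers is installed and that this module is importable as colossalai.shardformer.policies.auto_policy (a path inferred from the prefix used in import_policy, not stated in this diff); the config values are arbitrary and kept tiny.

from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer.policies.auto_policy import get_autopolicy    # assumed module path

# a deliberately small LLaMA so the example is cheap to instantiate
config = LlamaConfig(num_hidden_layers=2, hidden_size=128, intermediate_size=256, num_attention_heads=4)
model = LlamaForCausalLM(config)

# the lookup key is 'transformers.models.llama.modeling_llama.LlamaForCausalLM'
policy = get_autopolicy(model)
print(type(policy).__name__)    # expected: LlamaForCausalLMPolicy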
......@@ -75,6 +75,7 @@ class Policy(ABC):
"""
def __init__(self) -> None:
self.shard_config = None
self.model = None
......@@ -101,6 +102,7 @@ class Policy(ABC):
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""
pass
@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
......@@ -135,6 +137,7 @@ class Policy(ABC):
...
}
"""
pass
@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
......@@ -149,6 +152,7 @@ class Policy(ABC):
return BertModel_
```
"""
pass
@abstractmethod
def postprocess(self) -> nn.Module:
......@@ -156,3 +160,4 @@ class Policy(ABC):
Perform some postprocessing of the model, like binding the weight of embedding layer with
the classifier layer
"""
pass
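Taken together, preprocess, module_policy, new_model_class and postprocess are the full surface a model-specific policy has to implement. A minimal sketch for a hypothetical model follows; the class name MyModelPolicy is illustrative and not part of this commit, and the import path is assumed from the package layout shown above.

from typing import Dict, Type, Union

import torch.nn as nn

from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription, Policy


class MyModelPolicy(Policy):
    # hypothetical policy that changes nothing, just to show the required hooks

    def preprocess(self) -> nn.Module:
        # no embedding resizing needed in this toy example
        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        # an empty mapping means no submodule is replaced
        return {}

    def new_model_class(self) -> Union[Type[nn.Module], None]:
        # keep the original model class
        return None

    def postprocess(self) -> nn.Module:
        return self.model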
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Dict, Union
import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
import colossalai.shardformer.layer.layers as col_nn
from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class LlamaPolicy(Policy):
    def preprocess(self):
        # Resize embedding
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        return {
            LlamaDecoderLayer:
                ModulePolicyDescription(
                    attribute_replacement={
                        "self_attn.hidden_size":
                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                        "self_attn.num_heads":
                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                    },
                    param_replacement=[],
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="self_attn.q_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.k_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.v_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="self_attn.o_proj",
                            target_module=Linear1D_Row,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.gate_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.up_proj",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="mlp.down_proj",
                            target_module=Linear1D_Row,
                        )
                    ],
                ),
            LlamaModel:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(
                                                suffix="embed_tokens",
                                                target_module=VocabParallelEmbedding1D,
                                            )
                                        ])
        }

    def new_model_class(self):
        return None

    def postprocess(self):
        return self.model

    # removed in this commit: the old Argument-based description of the same sharding plan
    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        return {
            LlamaDecoderLayer:
                Argument(attr_dict={
                    "self_attn.hidden_size": config.hidden_size // world_size,
                    "self_attn.num_heads": config.num_attention_heads // world_size,
                },
                         param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
            LlamaModel:
                Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
        }
@staticmethod
def attn_layer() -> List:
return [
Col_Layer(
suffix="self_attn.q_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.k_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Col_Layer(
suffix="self_attn.v_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
),
Row_Layer(
suffix="self_attn.o_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
)
]
@staticmethod
def mlp_layer() -> List:
return [
Col_Layer(
suffix="mlp.gate_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
Col_Layer(
suffix="mlp.up_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Row,
gather_output=True,
),
Col_Layer(
suffix="mlp.down_proj",
weight="weight",
bias="bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True,
),
]
@staticmethod
def embeddings() -> List:
return [Col_Layer(
suffix="embed_tokens",
weight="weight",
replace_layer=col_nn.VocabParallelEmbedding1D,
)]
# removed in this commit: the old Argument-based model-specific policies
from transformers import LlamaForCausalLM


class LlamaForCausalLMPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
        argument.update(llamapolicy)

    @staticmethod
    def lm_head() -> List:
        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]


from transformers import LlamaForSequenceClassification


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {
            LlamaForSequenceClassification:
                Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
        }
        argument.update(llamapolicy)

    @staticmethod
    def score() -> List:
        return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]


# added in this commit: the new API versions extend LlamaPolicy.module_policy
class LlamaForCausalLMPolicy(LlamaPolicy):

    def module_policy(self):
        policy = super().module_policy()

        # add a new item for causal lm
        new_item = {
            LlamaForCausalLM:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="lm_head",
                                                                            target_module=Linear1D_Col,
                                                                            kwargs=dict(gather_output=True))
                                        ])
        }
        policy.update(new_item)
        return policy


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    def module_policy(self):
        policy = super().module_policy()

        # add a new item for sequence classification
        new_item = {
            LlamaForSequenceClassification:
                ModulePolicyDescription(attribute_replacement={},
                                        param_replacement=[],
                                        sub_module_replacement=[
                                            SubModuleReplacementDescription(suffix="score",
                                                                            target_module=Linear1D_Col,
                                                                            kwargs=dict(gather_output=True))
                                        ])
        }
        policy.update(new_item)
        return policy
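As a quick sanity check of what the causal-LM policy asks the sharder to replace, the sketch below instantiates it directly. It is illustrative only: a types.SimpleNamespace stands in for ShardConfig so no distributed launch is needed (only tensor_parallel_size is read here), it assumes the set_model / set_shard_config setters used by the sharder further down, and the tiny config values are arbitrary.

from types import SimpleNamespace

from transformers import LlamaConfig, LlamaForCausalLM

from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy

policy = LlamaForCausalLMPolicy()
policy.set_model(LlamaForCausalLM(LlamaConfig(num_hidden_layers=2, hidden_size=128,
                                              intermediate_size=256, num_attention_heads=4)))
policy.set_shard_config(SimpleNamespace(tensor_parallel_size=2))    # stand-in, not a real ShardConfig

mapping = policy.module_policy()
# expected keys: LlamaDecoderLayer, LlamaModel, LlamaForCausalLM (the last one adds the lm_head split)
print([k.__name__ for k in mapping])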
from dataclasses import dataclass
from typing import List, Literal
from colossalai.cluster.dist_coordinator import DistCoordinator
__all__ = ['ShardConfig']
......@@ -19,9 +20,19 @@ class ShardConfig:
gather_output (bool): Whether to gather the output of the last layer of the model
"""
tensor_parallel_size: int
# TODO: add support for the parallel modes and options commented out below
# pipeline_parallel_size: int
# data_parallel_size: int
tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
inference_only: bool = True
gather_output: bool = True
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
# inference_only: bool = True
# gather_output: bool = True
def __post_init__(self):
coordinator = DistCoordinator()
# ensure the parallel size can match the world size
world_size = coordinator.world_size
self.data_parallel_size = world_size // self.tensor_parallel_size
assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
    f"The world size ({world_size}) should be divisible by the tensor parallel size ({self.tensor_parallel_size})"
from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager
......@@ -41,10 +39,10 @@ class ModelSharder(object):
"""
self.policy.set_model(self.model)
self.policy.set_shard_config(self.shard_config)
self.preprocess()
self.replace_model_class()
self.replace_module()
self.postprocess()
self._preprocess()
self._replace_model_class()
self._replace_module()
self._postprocess()
def reshape_embedding(self) -> None:
r"""
......@@ -57,13 +55,13 @@ class ModelSharder(object):
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
def preprocess(self) -> None:
def _preprocess(self) -> None:
self.model = self.policy.preprocess()
def postprocess(self) -> None:
def _postprocess(self) -> None:
self.model = self.policy.postprocess()
def replace_model_class(self) -> None:
def _replace_model_class(self,) -> None:
r"""
Replace the model class with the one defined in the policy
Mainly modify the forward and backward passes to fit the distributed model
......@@ -84,7 +82,7 @@ class ModelSharder(object):
getattr(new_model_class, key),
)
def replace_module(self) -> None:
def _replace_module(self,) -> None:
r"""
Replace the modules according to the policy, one by one
......
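The sub_module_replacement entries in a policy are keyed by attribute paths ("suffixes") relative to the matched module. The diff does not show how _replace_module resolves them, but the idea can be sketched roughly as follows (illustrative only, not the actual implementation; replace_by_suffix is a made-up helper name).

import torch.nn as nn


def replace_by_suffix(root: nn.Module, suffix: str, new_module: nn.Module) -> None:
    # walk down to the parent of the target submodule, then rebind the last attribute,
    # e.g. suffix "self_attn.q_proj" -> getattr(root, "self_attn"), then set "q_proj"
    *parents, leaf = suffix.split('.')
    parent = root
    for name in parents:
        parent = getattr(parent, name)
    setattr(parent, leaf, new_module)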
......@@ -47,10 +47,12 @@ class ShardFormer:
"""
Initialize the distributed process group according to the
"""
# create process group manager and 1d process group
# TODO: may need to support other parallel mode when the config has such as field
pg_manager = ProcessGroupManager()
if (self.shard_config.tensor_parallel_mode == '1d'):
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager
return pg_manager
def shard_model(self, model: nn.Module, policy: Policy = None):
......
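Putting the pieces together, the intended calling sequence looks roughly like the sketch below; it mirrors the tests that follow, must run inside a process launched with colossalai.launch, and the BERT model and config values are just an example.

from transformers import BertConfig, BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

config = BertConfig(hidden_dropout_prob=0, attention_probs_dropout_prob=0)
model = BertForMaskedLM(config)

shard_config = ShardConfig(tensor_parallel_size=2)    # world size must be a multiple of 2
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()                       # creates the 'tp1d' process group
sharded_model = shard_former.shard_model(model).to('cuda')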
......@@ -24,21 +24,18 @@ CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def build_model(rank, world_size, model):
config = BertConfig.from_pretrained('bert-base-uncased')
def build_model(world_size, model_fn):
config = BertConfig()
config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model = model_fn(config=config)
org_model_forshard = copy.deepcopy(org_model)
org_model.to('cuda')
# TODO: no need to transfer to cuda
org_model_forshard.to('cuda')
shard_config = ShardConfig(
tensor_parallel_size=2,
tensor_parallel_mode='1d',
)
shard_config = ShardConfig(tensor_parallel_size=world_size,)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
......@@ -99,15 +96,22 @@ def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # removed in this commit:
    forward_list = [
        BertModel, BertForPreTraining, BertForMaskedLM, BertLMHeadModel, BertForNextSentencePrediction,
        BertForSequenceClassification
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    for model in forward_list:
        org_model, sharded_model = build_model(rank, world_size, model)
        check_forward(org_model, sharded_model)
        if model in backward_list:
            check_backward(org_model, sharded_model)

    # added in this commit:
    forward_list = [
        BertForMaskedLM,
        BertForPreTraining,
        BertLMHeadModel,

        # TODO: do not work yet
        # BertModel,
        # BertForSequenceClassification
        # BertForNextSentencePrediction,
    ]
    backward_list = [BertForMaskedLM, BertLMHeadModel]

    for model_fn in forward_list:
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)
        if model_fn in backward_list:
            check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
......
......@@ -4,31 +4,28 @@ import random
import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
# removed in this commit: the old build_model helper based on shard_model
def build_model(rank, world_size):
    cfg = LlamaConfig(num_hidden_layers=16)
    org_model = LlamaForCausalLM(cfg)
    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    org_model = org_model.to('cuda')

    org_model_forshard = copy.deepcopy(org_model)
    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

    return org_model, sharded_model


# added in this commit: build the model from a model_fn and shard it with ShardFormer
def build_model(world_size, model_fn):
    # create new model
    config = LlamaConfig(num_hidden_layers=8)
    org_model = model_fn(config).cuda()

    # shard model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model
......@@ -38,6 +35,7 @@ def check_forward(org_model, sharded_model):
inputs = tokenizer(input, return_tensors='pt').to('cuda')
del inputs["token_type_ids"]
del inputs["attention_mask"]
# original model
org_model.eval()
org_out = org_model(**inputs)
......@@ -87,11 +85,20 @@ def check_backward(org_model, sharded_model):
def check_llama(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # removed in this commit:
    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)

    # added in this commit:
    model_list = [
        LlamaForCausalLM,

        # TODO: do not work yet
        # LlamaModel,
        # LlamaForSequenceClassification
    ]

    for model_fn in model_list:
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward(org_model, sharded_model)
        check_backward(org_model, sharded_model)
torch.cuda.empty_cache()
......
......@@ -8,7 +8,7 @@ from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.shardformer.shard import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
......@@ -90,6 +90,7 @@ def check_t5(rank, world_size, port):
@pytest.mark.dist
@pytest.mark.skip
@rerun_if_address_is_in_use()
def test_t5():
spawn(check_t5, 2)
......