Commit dfca9678 authored by FoolPlayer's avatar FoolPlayer Committed by Frank Lee
Browse files

integrate with dist layer (#4011)

parent 015af592
...@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_ ...@@ -8,12 +8,6 @@ from ..utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
class ParallelModule():
def __init__(self):
pass
class BertPolicy(Policy): class BertPolicy(Policy):
def preprocess(self, shard_config: ShardConfig = None): def preprocess(self, shard_config: ShardConfig = None):
...@@ -49,7 +43,27 @@ class BertPolicy(Policy): ...@@ -49,7 +43,27 @@ class BertPolicy(Policy):
sub_module_replacement=[ sub_module_replacement=[
SubModuleReplacementDescription( SubModuleReplacementDescription(
suffix="attention.self.query", suffix="attention.self.query",
target_module=ParallelModule, target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.key",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.self.value",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="attention.output.dense",
target_module=col_nn.Linear1D_Row,
),
SubModuleReplacementDescription(
suffix="intermediate.dense",
target_module=col_nn.Linear1D_Col,
),
SubModuleReplacementDescription(
suffix="output.dense",
target_module=col_nn.Linear1D_Row,
), ),
]) ])
} }
......
...@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D ...@@ -7,8 +7,8 @@ from transformers.pytorch_utils import Conv1D
from colossalai.cluster.process_group_manager import ProcessGroupManager from colossalai.cluster.process_group_manager import ProcessGroupManager
from ..policies.autopolicy import get_autopolicy from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy from ..policies.basepolicy import Policy, SubModuleReplacementDescription
from ..utils.utils import setattr_ from ..utils.utils import getattr_, setattr_
from .shard_config import ShardConfig from .shard_config import ShardConfig
__all__ = ['ModelSharder', 'shard_model'] __all__ = ['ModelSharder', 'shard_model']
...@@ -90,9 +90,7 @@ class ModelSharder(object): ...@@ -90,9 +90,7 @@ class ModelSharder(object):
Args: Args:
model (:class:`torch.nn.Module`): The model to shard model (:class:`torch.nn.Module`): The model to shard
""" """
print(self.policy)
module_descriptions = self.policy.module_policy(self.shard_config) module_descriptions = self.policy.module_policy(self.shard_config)
print(f"*******{module_descriptions}")
for module_description in module_descriptions.items(): for module_description in module_descriptions.items():
origin_layer_cls = module_description[0] origin_layer_cls = module_description[0]
attr_replacement = module_description[1].attribute_replacement attr_replacement = module_description[1].attribute_replacement
...@@ -160,7 +158,7 @@ class ModelSharder(object): ...@@ -160,7 +158,7 @@ class ModelSharder(object):
def _replace_sub_module( def _replace_sub_module(
self, self,
org_layer: nn.Module, org_layer: nn.Module,
sub_module_replacement: List[Callable], sub_module_replacement: List[SubModuleReplacementDescription],
) -> None: ) -> None:
r""" r"""
Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict
...@@ -177,7 +175,8 @@ class ModelSharder(object): ...@@ -177,7 +175,8 @@ class ModelSharder(object):
assert target_module is not None, 'target_module should not be None' assert target_module is not None, 'target_module should not be None'
# TODO: integrate with new layer # TODO: support different parallel mode
# replace_layer = target_module.from_native_layer(org_layer, self.pg_manager) native_sub_module = getattr_(org_layer, suffix)
replace_layer = None replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'])
setattr_(org_layer, suffix, replace_layer) setattr_(org_layer, suffix, replace_layer)
...@@ -17,7 +17,7 @@ from transformers import ( ...@@ -17,7 +17,7 @@ from transformers import (
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
...@@ -30,16 +30,21 @@ def build_model(rank, world_size, model): ...@@ -30,16 +30,21 @@ def build_model(rank, world_size, model):
config.hidden_dropout_prob = 0 config.hidden_dropout_prob = 0
config.attention_probs_dropout_prob = 0 config.attention_probs_dropout_prob = 0
org_model = model(config=config) org_model = BertForMaskedLM.from_pretrained('bert-base-uncased', config=config)
org_model_forshard = copy.deepcopy(org_model) org_model_forshard = copy.deepcopy(org_model)
org_model = org_model.to('cuda') org_model.to('cuda')
shardconfig = ShardConfig( # TODO: no need to transfer to cuda
rank=rank, org_model_forshard.to('cuda')
world_size=world_size, shard_config = ShardConfig(tensor_parallel_size=2,
gather_output=True, data_parallel_size=1,
) pipeline_parallel_size=1,
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') tensor_parallel_mode='1d',
inference_only=True,
gather_output=True)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(org_model_forshard).to('cuda')
return org_model, sharded_model return org_model, sharded_model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment