Commit 79f8d5d5 authored by FoolPlayer, committed by Frank Lee

[shardformer] add gpt2 policy and modify shard and slicer to support (#3883)

* add gpt2 policy and modify shard and slicer to support

* remove unused code

* polish code
parent 70173e31
......@@ -10,16 +10,26 @@ def build_policies():
"""
auto_policy_dict = {}
from transformers.models.bert.modeling_bert import BertForMaskedLM
from transformers import BertForMaskedLM
from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from transformers import BertForSequenceClassification
from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers import GPT2Model
from .gpt2 import GPT2Policy
auto_policy_dict[GPT2Model] = GPT2Policy
from transformers import GPT2LMHeadModel
from .gpt2 import GPT2LMHeadModelPolicy
auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy
return auto_policy_dict
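For orientation, a minimal sketch of how a mapping like `auto_policy_dict` can be resolved for a concrete model instance; the `lookup_policy` helper below is illustrative, not part of this file:

```python
import torch.nn as nn

def lookup_policy(model: nn.Module, auto_policy_dict: dict):
    # Policies are keyed by the exact model class, so a plain dict lookup suffices.
    policy = auto_policy_dict.get(model.__class__)
    if policy is None:
        raise NotImplementedError(f"no shard policy registered for {type(model).__name__}")
    return policy
```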
......
# part of code modified from https://github.com/tunib-ai/parallelformers
from dataclasses import dataclass, field
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch
import torch.nn as nn
from transformers import AutoConfig
@dataclass
......@@ -31,11 +29,18 @@ class Layer:
bias (str): The bias suffix of the layer
replace_layer (:class:`colossalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the layer's weight is stored transposed; `torch.nn.Linear` stores its weight as [out, in],
but GPT2's `Conv1D` stores it as [in, out], i.e. reversed.
n_cast (int): The number of sub-weights fused into this weight, e.g. 3 for the fused q, k, v projection in an
attention layer. Plain TP chunks a weight into world_size parts, but a fused weight must be chunked into
world_size * n_cast parts and regrouped so that each device holds its share of every sub-weight
(a numeric sketch follows the dataclass).
"""
weight: str = None
bias: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
n_cast: int = None
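As a numeric sketch of the `n_cast` chunking described above (illustrative only, with `world_size=2` ranks and a fused Q/K/V weight, so `n_cast=3`): the weight is split into `world_size * n_cast` chunks, and each rank keeps every `world_size`-th chunk so it ends up with a slice of Q, K and V.

```python
import torch

world_size, n_cast, rank = 2, 3, 0
w = torch.arange(12.)                  # stand-in fused projection, laid out [Q | K | V]
chunks = w.chunk(world_size * n_cast)  # 6 chunks of 2 elements each
shard = torch.cat([chunks[i] for i in range(rank, len(chunks), world_size)])
print(shard)                           # tensor([0., 1., 4., 5., 8., 9.]): half of Q, K and V
```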
@dataclass
......@@ -131,7 +136,7 @@ class Policy():
(OriginModel, CustomModel)
in `CustomModel`, the forward and backward processes can be overridden
"""
return ()
return None
@staticmethod
def binding_policy() -> Dict:
......@@ -146,7 +151,7 @@ class Policy():
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
return None
@staticmethod
def attn_in() -> List:
......@@ -209,4 +214,4 @@ class Policy():
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None
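These `None` returns define the contract the sharder checks before acting: a policy that needs no model injection and no weight binding simply inherits them. A minimal sketch, assuming `Policy` from this file:

```python
# Illustrative only: a policy that overrides nothing but argument_policy
# falls back to the base class's None-returning inject_policy/binding_policy.
class MyModelPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size):
        # Map each module class to an Argument(attr_dict=..., param_funcs=[...]).
        return {}
```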
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
......
from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class GPT2Policy(Policy):
@staticmethod
def argument_policy(config, world_size):
return {
GPT2Model:
Argument(attr_dict={}, param_funcs=[
GPT2Policy.embedding,
]),
GPT2Block:
Argument(
attr_dict={
# 1. reduce hidden size
"attn.embed_dim": config.hidden_size // world_size,
"attn.split_size": config.hidden_size // world_size,
"crossattention.embed_dim": config.hidden_size // world_size,
"crossattention.split_size": config.hidden_size // world_size,
# 2. reduce number of heads
"attn.num_heads": config.num_attention_heads // world_size,
"crossattention.num_heads": config.num_attention_heads // world_size,
},
param_funcs=[
GPT2Policy.attn_in,
GPT2Policy.attn_out,
GPT2Policy.mlp_in,
GPT2Policy.mlp_out,
]),
}
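The attribute overrides above shrink each block's bookkeeping to its local shard; e.g. for GPT2-small (hidden size 768, 12 heads) on 2 ranks:

```python
hidden_size, num_attention_heads, world_size = 768, 12, 2
print(hidden_size // world_size)           # 384: per-rank embed_dim and split_size
print(num_attention_heads // world_size)   # 6:   per-rank attention heads
```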
@staticmethod
def attn_in() -> List:
return [
Col_Layer(weight="attn.c_attn.weight",
bias="attn.c_attn.bias",
n_cast=3,
reversed=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.c_attn.weight",
bias="crossattention.c_attn.bias",
n_cast=2,
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col),
Col_Layer(weight="crossattention.q_attn.weight",
bias="crossattention.q_attn.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Col)
]
@staticmethod
def attn_out() -> List:
return [
Row_Layer(weight="attn.c_proj.weight",
bias="attn.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row),
Row_Layer(weight="crossattention.c_proj.weight",
bias="crossattention.c_proj.bias",
reversed=True,
ignore=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def mlp_in() -> List:
return [
Col_Layer(weight="mlp.c_fc.weight", bias="mlp.c_fc.bias", reversed=True, replace_layer=col_nn.Linear1D_Col),
]
@staticmethod
def mlp_out() -> List:
return [
Row_Layer(weight="mlp.c_proj.weight",
bias="mlp.c_proj.bias",
reversed=True,
replace_layer=col_nn.Linear1D_Row)
]
@staticmethod
def embedding() -> List:
return [Col_Layer(weight="wte.weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
from transformers import GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
@staticmethod
def argument_policy(config, world_size):
base_argument = GPT2Policy.argument_policy(config, world_size)
argument = {
GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
GPT2LMHeadModelPolicy.unembedding,
]),
}
argument.update(base_argument)
return argument
@staticmethod
def unembedding() -> List:
return [
Col_Layer(weight="lm_head.weight",
bias="lm_head.bias",
replace_layer=col_nn.Linear1D_Col,
gather_output=True)
]
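Two details of this policy are worth spelling out. First, every entry sets `reversed=True` because GPT2 uses `transformers`' `Conv1D`, which stores its weight as [in_features, out_features], the transpose of `nn.Linear`. Second, `n_cast=3` on `attn.c_attn` reflects the fused Q/K/V projection, while the cross-attention `c_attn` fuses only K and V (`n_cast=2`) since its query comes from the separate `q_attn`. A quick shape check of the first point (requires only `torch` and `transformers`):

```python
import torch.nn as nn
from transformers.pytorch_utils import Conv1D

linear = nn.Linear(4, 8)
conv1d = Conv1D(8, 4)          # Conv1D(nf=out_features, nx=in_features)
print(linear.weight.shape)     # torch.Size([8, 4]) -> [out, in]
print(conv1d.weight.shape)     # torch.Size([4, 8]) -> [in, out], i.e. reversed
```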
......@@ -2,6 +2,7 @@ from typing import Any, Callable, Dict, List
import torch
import torch.nn as nn
from transformers.pytorch_utils import Conv1D
from ..policies.autopolicy import get_autopolicy
from ..policies.basepolicy import Policy
......@@ -35,10 +36,22 @@ class ModelSharder(object):
self.model_config = self.model.config
def shard(self) -> None:
self.reshape_embedding()
self.inject_model(self.model)
self.replace_layer(self.model)
self.bind_layer(self.model)
def reshape_embedding(self) -> None:
r"""
Resize the token embeddings so that the vocabulary size is divisible by world_size
"""
vocab_size = self.model_config.vocab_size
world_size = self.shard_config.world_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config
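The padding arithmetic rounds the vocabulary up to the next multiple of `world_size`, e.g.:

```python
def pad_vocab(vocab_size: int, world_size: int) -> int:
    # Round vocab_size up to the next multiple of world_size.
    if vocab_size % world_size != 0:
        return vocab_size + world_size - vocab_size % world_size
    return vocab_size

print(pad_vocab(50257, 4))   # 50260: GPT2's vocabulary padded for 4-way TP
```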
def inject_model(
self,
model: nn.Module,
......@@ -53,6 +66,8 @@ class ModelSharder(object):
"""
inject_policy = self.policy.inject_policy()
if inject_policy is None:
return
org_model_cls = inject_policy[0]
shard_model_cls = inject_policy[1]
......@@ -82,9 +97,9 @@ class ModelSharder(object):
origin_layer_cls = argument_policy[0]
attr_dict = argument_policy[1].attr_dict
param_funcs = argument_policy[1].param_funcs
self.reverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
self.traverse_replace_layer(model, origin_layer_cls, attr_dict, param_funcs)
def reverse_replace_layer(
def traverse_replace_layer(
self,
layer: nn.Module,
origin_cls: nn.Module,
......@@ -100,17 +115,12 @@ class ModelSharder(object):
attr_dict (Dict): The attribute dict to modify
policy_cls (:class:`Policy`): The policy class
"""
for name, child in layer.named_children():
if child.__class__ == origin_cls:
# replac_layer = child
if layer.__class__ == origin_cls:
for k, v in attr_dict.items():
setattr_(child, k, v, ignore=True)
# print(f"Sharding {name} layer", replac_layer.attention.self.__dict__)
# setattr_(layer, name, self.shard_one_layer(child, policy_cls))
self.shard_one_layer(child, param_funcs)
continue
self.reverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
setattr_(layer, k, v, ignore=True)
self.shard_one_layer(layer, param_funcs)
for name, child in layer.named_children():
self.traverse_replace_layer(child, origin_cls, attr_dict, param_funcs)
return layer
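The renamed `traverse_replace_layer` now visits the module itself before recursing into all children, so nested occurrences of `origin_cls` are sharded too. A minimal sketch of the same pre-order pattern on plain `nn.Module`s (names are illustrative):

```python
import torch.nn as nn

def traverse(module: nn.Module, target_cls, visit):
    # Visit the module first, then recurse, so nested matches are handled.
    if isinstance(module, target_cls):
        visit(module)
    for child in module.children():
        traverse(child, target_cls, visit)

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 2)))
traverse(model, nn.Linear, print)   # prints both Linear layers
```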
def shard_one_layer(
......@@ -126,7 +136,6 @@ class ModelSharder(object):
param_funcs (:class:`List[typing.Callable]`): The function list to get shard information in policy class
"""
# print(org_layer)
for func in param_funcs:
policy_layers = func()
for policy_layer in policy_layers:
......@@ -136,9 +145,10 @@ class ModelSharder(object):
bias_attr = policy_layer.bias
replace_layer_cls = policy_layer.replace_layer
ignore = policy_layer.ignore
n_cast = policy_layer.n_cast
reversed = policy_layer.reversed
if policy_layer.__class__.__name__ == "Col_Layer":
gather_output = policy_layer.gather_output
# print(gather_output)
if weight_attr is not None:
if hasattr_(org_layer, weight_attr):
......@@ -161,13 +171,11 @@ class ModelSharder(object):
layer_attr = (lambda x: x[:x.rfind(".")])(weight_attr or bias_attr)
# slice weight and bias
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__)
# print(os.environ['RANK'], policy_layer.__class__, weight.shape, bias.shape if bias is not None else None)
weight, bias = self.slicer.slice_weight_bias(weight, bias, policy_layer.__class__, n_cast, reversed)
# create new object to replace the origin layer
if replace_layer_cls is not None:
# print(f"RANK {os.environ['RANK']}: replace {getattr_(org_layer, layer_attr).__class__} to {replace_layer_cls}, shape is {weight.shape}")
if isinstance(getattr_(org_layer, layer_attr), nn.Linear):
if isinstance(getattr_(org_layer, layer_attr), (nn.Linear, Conv1D)):
if replace_layer_cls.__name__ == "Linear1D_Row":
replace_layer = replace_layer_cls(weight.shape[1],
weight.shape[0],
......@@ -235,6 +243,8 @@ class ModelSharder(object):
model (:class:`torch.nn.Module`): The shard model
"""
binding_map = self.policy.binding_policy()
if binding_map is None:
return
for k, v in binding_map.items():
param = getattr_(model, k)
param = nn.Parameter(param)
......
......@@ -19,6 +19,8 @@ class Slicer():
weight: torch.Tensor,
bias: torch.Tensor,
policy_layer_cls: Layer,
n_cast: int = None,
reversed: bool = False,
):
r"""
Slice the weight and bias according to policy layer cls
......@@ -33,13 +35,18 @@ class Slicer():
"""
if policy_layer_cls == Layer:
return weight, bias
elif policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, 1, False)
dim = dim_mapping[policy_layer_cls] if not reversed else (1 - dim_mapping[policy_layer_cls])
# print(weight.shape, dim)
if policy_layer_cls == Col_Layer:
weight = self.slice_tensor(weight, dim, False, n_cast)
bias = self.slice_tensor(bias, 0, True)
elif policy_layer_cls == Row_Layer:
weight = self.slice_tensor(weight, 0, False)
weight = self.slice_tensor(weight, dim, False, n_cast)
else:
raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
if reversed:
weight = weight.transpose(0, 1).contiguous()
return weight, bias
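Combining `reversed` with the split: for a Conv1D-style [in, out] weight, a column-parallel layer must chunk dim 1 rather than dim 0, and the shard is transposed afterwards so the replacement `Linear1D` layer receives the usual [out, in] layout. A sketch with `world_size=2`:

```python
import torch

world_size, rank = 2, 0
w = torch.randn(4, 8)                        # Conv1D-style weight: [in, out]
shard = w.chunk(world_size, dim=1)[rank]     # split the output dim -> [4, 4]
shard = shard.transpose(0, 1).contiguous()   # back to Linear-style [out, in]
print(shard.shape)                           # torch.Size([4, 4])
```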
def slice_tensor(
......@@ -47,6 +54,7 @@ class Slicer():
tensor_in: torch.Tensor,
dim: int,
is_bias: bool,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice tensor according to the config
......@@ -59,14 +67,15 @@ class Slicer():
if tensor_in is None:
return None
if not is_bias:
return self.slice_2d(tensor_in, dim)
return self.slice_2d(tensor_in, dim, n_cast)
else:
return self.slice_1d(tensor_in)
return self.slice_1d(tensor_in, n_cast)
def slice_2d(
self,
tensor: torch.Tensor,
dim: int,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 2D tensor
......@@ -77,13 +86,14 @@ class Slicer():
"""
assert dim in [0, 1], f"dim must be 0 or 1 for a 2D tensor, but got dim={dim}"
if dim == 0:
return self.slice_row(tensor)
return self.slice_row(tensor, n_cast)
elif dim == 1:
return self.slice_col(tensor)
return self.slice_col(tensor, n_cast)
def slice_1d(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the 1D tensor
......@@ -94,11 +104,19 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()
def slice_col(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in column
......@@ -110,11 +128,19 @@ class Slicer():
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=0).contiguous()
def slice_row(
self,
tensor: torch.Tensor,
n_cast: int = None,
) -> torch.Tensor:
r"""
Slice the tensor in row
......@@ -125,4 +151,11 @@ class Slicer():
Returns:
:class:`torch.Tensor`: The sliced tensor
"""
if n_cast is None:
return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
else:
tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
chunk_list = [
tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
]
return torch.cat(chunk_list, dim=1).contiguous()
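`slice_row` applies the same interleaved gather along dim 1; e.g. with `world_size=2` and `n_cast=2`, rank 0 keeps the first half of each fused column block:

```python
import torch

world_size, n_cast, rank = 2, 2, 0
w = torch.arange(16.).reshape(2, 8)            # [out, in], input laid out [A | B]
chunks = w.chunk(world_size * n_cast, dim=1)   # 4 column blocks of width 2
shard = torch.cat([chunks[i] for i in range(rank, len(chunks), world_size)], dim=1)
print(shard)                                   # columns 0-1 and 4-5: rank 0's part of A and B
```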
......@@ -6,24 +6,28 @@ import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, get_scheduler
from transformers import AutoTokenizer, BertForMaskedLM, DataCollatorForLanguageModeling, GPT2LMHeadModel, get_scheduler
import colossalai
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.utils import get_current_device, print_rank_0
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
def get_args():
parser = colossalai.get_default_parser()
parser.add_argument("--mode", type=str, default='inference')
parser.add_argument("--save_model", action='store_true')
parser.add_argument("--model", type=str, default='bert-base-uncased')
return parser.parse_args()
def load_data():
def load_data(args):
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
# tokenizer.pad_token_id = 0
datasets = load_dataset('wikitext', 'wikitext-2-raw-v1')
# datasets=load_dataset("yelp_review_full")
tokenized_datasets = datasets.map(
......@@ -42,18 +46,23 @@ def load_data():
def inference(model: nn.Module, args):
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
print(model)
# print(model.wte.weight.shape)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.pad_token is None:
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
tokenizer.pad_token_id = 0
token = "Hello, my dog is cute"
inputs = tokenizer(token, return_tensors="pt")
inputs.to("cuda")
model.eval()
model.to("cuda")
outputs = model(**inputs)
print(outputs)
print(outputs[0])
def train(model: nn.Module, args, num_epoch: int = 3):
train_dataloader, eval_dataloader = load_data()
train_dataloader, eval_dataloader = load_data(args)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
num_training = num_epoch * len(train_dataloader)
progress_bar = tqdm(range(num_training))
......@@ -94,8 +103,13 @@ def train(model: nn.Module, args, num_epoch: int = 3):
if __name__ == "__main__":
args = get_args()
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
colossalai.launch_from_torch(config=args.config)
if args.model == 'bert-base-uncased':
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
elif args.model == 'gpt2':
model = GPT2LMHeadModel.from_pretrained("gpt2")
else:
raise ValueError(f"model {args.model} is not supported")
shard_config = ShardConfig(
rank=int(str(get_current_device()).split(':')[-1]),
world_size=int(os.environ['WORLD_SIZE']),
......
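For reference, a hedged sketch of how the script is then expected to continue; the `shard_model(model, shard_config)` call is assumed from the import above, and the elided lines are not reproduced here:

```python
# launch e.g. with: torchrun --nproc_per_node 2 shard_gpt2.py --model gpt2 --mode inference
sharded_model = shard_model(model, shard_config)
if args.mode == 'inference':
    inference(sharded_model, args)
elif args.mode == 'train':
    train(sharded_model, args)
```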