"vscode:/vscode.git/clone" did not exist on "61da3fbc524c8c7939d194007d91488b89288dc5"
Unverified Commit 7172459e authored by Wenhao Chen, committed by GitHub

[shardformer]: support gpt-j, falcon, Mistral and add interleaved pipeline for bert (#5088)



* [shardformer] implement policy for all GPT-J models and test

* [shardformer] support interleaved pipeline parallel for bert finetune

* [shardformer] shardformer support falcon (#4883)

* [shardformer]: fix interleaved pipeline for bert model (#5048)

* [hotfix]: disable seq parallel for gptj and falcon, and polish code (#5093)

* Add Mistral support for Shardformer (#5103)

* [shardformer] add tests to mistral (#5105)

---------
Co-authored-by: Pengtai Xu <henryxu880@gmail.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
Co-authored-by: flybird11111 <1829166702@qq.com>
Co-authored-by: eric8607242 <e0928021388@gmail.com>
parent 126cf180
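Note: the bert finetune example below is switched to the interleaved pipeline schedule. A minimal sketch of that configuration (argument values copied from the finetune diff in this commit; the rest of the script, model, optimizer and dataloaders, is assumed rather than shown):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Interleaved 1F1B: each pipeline rank holds `num_model_chunks` chunks of the model,
# so both the pipeline size and the chunk count are passed to the plugin.
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=2,
    num_microbatches=None,
    microbatch_size=16,
    pp_style="interleaved",
    num_model_chunks=2,
    enable_all_optimization=True,
    zero_stage=1,
    precision="fp16",
)
booster = Booster(plugin=plugin)
# booster.boost(...) then wraps the model/optimizer/dataloader, and
# booster.execute_pipeline(...) runs the interleaved schedule during training.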
from collections import namedtuple
import psutil
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
_GLOBAL_CUDA_MEM_FRACTION = 1.0
_GLOBAL_CPU_MEM_CAPACITY = -1
# copy from PatrickStar
def _get_cpu_memory_info():
ps_mem_info = namedtuple("ps_mem_info", ["total", "free", "cached", "buffers", "used"])
try:
        # psutil reads the memory info from /proc/meminfo,
        # which reports the host's memory rather than
        # the container's.
# Here we try to read the container memory with method in:
# https://stackoverflow.com/a/46213331/5163915
mems = {}
with open("/sys/fs/cgroup/memory/memory.meminfo", "rb") as f:
for line in f:
fields = line.split()
mems[fields[0]] = int(fields[1]) * 1024
total = mems[b"MemTotal:"]
free = mems[b"MemFree:"]
cached = mems[b"Cached:"]
buffers = mems[b"Buffers:"]
used = total - free - cached - buffers
if used < 0:
used = total - free
mem_info = ps_mem_info(total=total, free=free, cached=cached, buffers=buffers, used=used)
except FileNotFoundError:
mems = psutil.virtual_memory()
mem_info = ps_mem_info(
total=mems.total,
free=mems.free,
cached=mems.cached,
buffers=mems.buffers,
used=mems.used,
)
return mem_info
def colo_device_memory_capacity(device: torch.device) -> int:
"""
Get the capacity of the memory of the device
Args:
device (torch.device): a device
Returns:
int: size in byte
"""
# TODO: add NPU support
assert isinstance(device, torch.device)
if device.type == "cpu":
# In the context of 1-CPU-N-GPU, the memory capacity of the current process is 1/N overall CPU memory.
return colo_get_cpu_memory_capacity() // dist.get_world_size()
if device.type == "cuda":
return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION
def colo_get_cpu_memory_capacity() -> int:
"""
Get the cpu memory capacity. We may not use all of it.
Returns:
        int: the CPU memory capacity in bytes
"""
global _GLOBAL_CPU_MEM_CAPACITY
if _GLOBAL_CPU_MEM_CAPACITY == -1:
mem_info = _get_cpu_memory_info()
return mem_info.total
else:
return _GLOBAL_CPU_MEM_CAPACITY
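A small usage sketch of the helpers above; the device handling follows the functions as defined, while the assumption that torch.distributed is already initialized (required by the CPU branch) is mine:

import torch

# CPU branch: the per-process budget is the container's CPU memory divided by the world size,
# e.g. 4 processes sharing 256 GiB are budgeted 64 GiB each.
cpu_budget = colo_device_memory_capacity(torch.device("cpu"))

# CUDA branch: the device's total memory scaled by the global CUDA memory fraction (1.0 by default).
cuda_budget = colo_device_memory_capacity(torch.device("cuda"))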
-from .gemini import (
-    ColoInitContext,
-    GeminiAdamOptimizer,
-    GeminiDDP,
-    GeminiOptimizer,
-    get_static_torch_model,
-    post_process_colo_init_ctx,
-)
+from .gemini import GeminiAdamOptimizer, GeminiDDP, GeminiOptimizer, get_static_torch_model
 from .low_level import LowLevelZeroOptimizer
 from .wrapper import zero_model_wrapper, zero_optim_wrapper
@@ -16,7 +9,5 @@ __all__ = [
     "zero_model_wrapper",
     "zero_optim_wrapper",
     "LowLevelZeroOptimizer",
-    "ColoInitContext",
-    "post_process_colo_init_ctx",
     "get_static_torch_model",
 ]

 from .chunk import ChunkManager, TensorInfo, TensorState, search_chunk_configuration
-from .colo_init_context import ColoInitContext, post_process_colo_init_ctx
 from .gemini_ddp import GeminiDDP
 from .gemini_mgr import GeminiManager
 from .gemini_optimizer import GeminiAdamOptimizer, GeminiOptimizer
@@ -15,6 +14,4 @@ __all__ = [
     "get_static_torch_model",
     "GeminiAdamOptimizer",
     "GeminiOptimizer",
-    "ColoInitContext",
-    "post_process_colo_init_ctx",
 ]

@@ -6,8 +6,8 @@ from typing import Dict, List, Optional, Tuple, Type
 import torch

+from colossalai.legacy.utils.memory import colo_device_memory_capacity
 from colossalai.utils import get_current_device
-from colossalai.utils.memory import colo_device_memory_capacity
 from colossalai.zero.gemini.chunk import Chunk

 from .chunk import Chunk, ChunkManager
...
@@ -178,6 +178,18 @@ Model/Feature Compatibility Matrix:
 <td nowrap="nowrap" align="center"></td>
 <td nowrap="nowrap" align="center"></td>
 </tr>
+<tr>
+<td nowrap="nowrap">Falcon</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center"></td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center"></td>
+<td nowrap="nowrap" align="center"></td>
+</tr>
 <tr>
 <td colspan="39"></td>
 </tr>
...
@@ -174,6 +174,18 @@ Author: [Baizhou Zhang](https://github.com/Fridge003), [Bin Jia](https://github.
 <td nowrap="nowrap" align="center"></td>
 <td nowrap="nowrap" align="center"></td>
 </tr>
+<tr>
+<td nowrap="nowrap">Falcon</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center"></td>
+<td nowrap="nowrap" align="center">✔️</td>
+<td nowrap="nowrap" align="center"></td>
+<td nowrap="nowrap" align="center"></td>
+</tr>
 <tr>
 <td colspan="39"></td>
 </tr>
...
@@ -88,20 +88,24 @@ class GLUEDataBuilder:
         )

     def val_dataloader(self):
+        # TODO: drop_last is set to True for now to avoid error when using PP
+        # as the last batch may not be divisible by the number of microbatches
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["validation"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(
+                self.dataset["validation"], batch_size=self.eval_batch_size, drop_last=True
+            )
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
                 for x in self.eval_splits
             ]

     def test_dataloader(self):
         if len(self.eval_splits) == 1:
-            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size)
+            return self.plugin.prepare_dataloader(self.dataset["test"], batch_size=self.eval_batch_size, drop_last=True)
         elif len(self.eval_splits) > 1:
             return [
-                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size)
+                self.plugin.prepare_dataloader(self.dataset[x], batch_size=self.eval_batch_size, drop_last=True)
                 for x in self.eval_splits
             ]
...
@@ -57,7 +57,9 @@ def evaluate_model(
     def evaluate_subset(dataloader: DataLoader):
         use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-        is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
+        is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
+            None if not booster.plugin.stage_manager.is_interleave else -1
+        )

         accum_loss = torch.zeros(1, device=get_current_device())
         for batch in dataloader:
@@ -69,9 +71,10 @@ def evaluate_model(
                 current_pp_group_ranks = pg_mesh.get_ranks_in_group(pp_group)
                 current_rank = dist.get_rank()
                 batch = iter([batch])
                 outputs = booster.execute_pipeline(batch, model, criterion, return_loss=True, return_outputs=True)

-                if is_pp_last_stage:
+                if is_pp_last_device:
                     logits = outputs["outputs"]["logits"]
                     val_loss = outputs["loss"]
                     accum_loss.add_(val_loss)
@@ -133,8 +136,10 @@ def train_epoch(
     coordinator: DistCoordinator,
 ):
     use_pipeline = isinstance(booster.plugin, HybridParallelPlugin) and booster.plugin.pp_size > 1
-    is_pp_last_stage = use_pipeline and booster.plugin.stage_manager.is_last_stage()
-    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_stage)
+    is_pp_last_device = use_pipeline and booster.plugin.stage_manager.is_last_stage(
+        None if not booster.plugin.stage_manager.is_interleave else -1
+    )
+    print_flag = (not use_pipeline and coordinator.is_master()) or (use_pipeline and is_pp_last_device)
     total_step = len(train_dataloader)

     model.train()
@@ -148,7 +153,7 @@ def train_epoch(
                 train_dataloader_iter, model, _criterion, optimizer, return_loss=True, return_outputs=True
             )
             # Backward and optimize
-            if is_pp_last_stage:
+            if is_pp_last_device:
                 loss = outputs["loss"]
                 pbar.set_postfix({"loss": loss.item()})
             else:
@@ -222,7 +227,9 @@ def main():
             tp_size=1,
             pp_size=2,
             num_microbatches=None,
-            microbatch_size=1,
+            pp_style="interleaved",
+            num_model_chunks=2,
+            microbatch_size=16,
             enable_all_optimization=True,
             zero_stage=1,
             precision="fp16",
...
@@ -71,6 +71,10 @@ class ModelZooRegistry(dict):
         new_dict = dict()
         for k, v in self.items():
+            if keyword == "transformers_gpt":
+                if keyword in k and not "gptj" in k:  # ensure GPT2 does not retrieve GPTJ models
+                    new_dict[k] = v
+            else:
                 if keyword in k:
                     new_dict[k] = v
...
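The effect of this special case, sketched with registry names registered elsewhere in this commit (assuming this is the get_sub_registry path used by the tests below; the assertion is illustrative):

from tests.kit.model_zoo import model_zoo

gpt2_entries = model_zoo.get_sub_registry("transformers_gpt")    # GPT2 entries only, GPT-J excluded
gptj_entries = model_zoo.get_sub_registry("transformers_gptj")   # GPT-J entries are requested explicitly
assert not any("gptj" in name for name in gpt2_entries)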
@@ -3,10 +3,17 @@ from .bert import *
 from .blip2 import *
 from .bloom import *
 from .chatglm2 import *
+from .falcon import *
 from .gpt import *
+from .gptj import *
 from .llama import *
 from .opt import *
 from .sam import *
 from .t5 import *
 from .vit import *
 from .whisper import *
+
+try:
+    from .mistral import *
+except ImportError:
+    print("This version of transformers doesn't support mistral.")
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register Falcon
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoTokenizer
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
    # for LM, `labels` are the target tokens; since there is no padding, `input_ids` are reused as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_token_classification():
# token classification data gen
    # for token classification, `labels` are class indices (0 or 1), not token ids
data = data_gen()
data["labels"] = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.int64)
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([0], dtype=torch.int64)
return data
def data_gen_for_question_answering():
input_ids = torch.tensor(
[[57647, 1620, 23967, 620, 107373, 34, 91514, 620, 107373, 1620, 267, 35378, 48946, 18161, 48946, 18161]],
dtype=torch.int64,
)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
start_positions = torch.tensor([1], dtype=torch.int64)
end_positions = torch.tensor([10], dtype=torch.int64)
return dict(
input_ids=input_ids, attention_mask=attention_mask, start_positions=start_positions, end_positions=end_positions
)
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_falcon_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
config = transformers.FalconConfig(
num_hidden_layers=2,
num_attention_heads=4,
vocab_size=250880,
hidden_dropout=0,
attention_dropout=0,
hidden_size=64,
multi_query=False,
new_decoder_architecture=True,
pad_token_id=-1,
)
model_zoo.register(
name="transformers_falcon",
model_fn=lambda: transformers.FalconModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_falcon_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_falcon_for_causal_lm",
model_fn=lambda: transformers.FalconForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_falcon_for_sequence_classification",
model_fn=lambda: transformers.FalconForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_falcon_for_token_classification",
model_fn=lambda: transformers.FalconForTokenClassification(config),
data_gen_fn=data_gen_for_token_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_falcon_for_question_answering",
model_fn=lambda: transformers.FalconForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_question_answering,
model_attribute=ModelAttribute(has_control_flow=True),
)
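As a rough illustration of how these registrations are consumed (the loop mirrors the shardformer tests later in this commit; running it standalone like this is my assumption, not part of the test harness):

from tests.kit.model_zoo import model_zoo

sub_zoo = model_zoo.get_sub_registry("transformers_falcon")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_zoo.items():
    model = model_fn()    # builds a tiny Falcon variant from the config above
    data = data_gen_fn()  # synthetic input_ids / attention_mask (plus labels where needed)
    loss = loss_fn(output_transform_fn(model(**data)))
    loss.backward()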
@@ -14,7 +14,7 @@ def data_gen():
     # Generated from following code snippet
     #
     # from transformers import GPT2Tokenizer
-    # input = 'Hello, my dog is cute'
+    # input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement)
     # tokenized_input = tokenizer(input, return_tensors='pt')
     # input_ids = tokenized_input['input_ids']
     # attention_mask = tokenized_input['attention_mask']
...
import copy
import torch
import transformers
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence GPT
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoTokenizer
# input = 'Hello, my dog is cute is cute' (last two words repeated to satisfy length requirement)
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
# tokenized_input = tokenizer(input, return_tensors='pt')
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[15496, 11, 616, 3290, 318, 13779, 318, 13779]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
    # for LM, `labels` are the target tokens; since there is no padding, `input_ids` are reused as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_question_answering():
# question answering data gen
    # question answering uses `start_positions` and `end_positions` instead of `labels`
data = data_gen()
start_positions = torch.tensor([0], dtype=torch.int64)
data["start_positions"] = start_positions
end_positions = torch.tensor([1], dtype=torch.int64)
data["end_positions"] = end_positions
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_gptj_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
config = transformers.GPTJConfig(
n_layer=2,
n_head=16,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
resid_pdrop=0,
hidden_dropout=0,
problem_type="single_label_classification",
pad_token_id=50256,
)
config_for_token_classification = copy.deepcopy(config)
config_for_token_classification.num_labels = 2
# register the following models
model_zoo.register(
name="transformers_gptj",
model_fn=lambda: transformers.GPTJModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_gptj_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gptj_lm",
model_fn=lambda: transformers.GPTJForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gptj_for_question_answering",
model_fn=lambda: transformers.GPTJForQuestionAnswering(config),
data_gen_fn=data_gen_for_question_answering,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_gptj_for_sequence_classification",
model_fn=lambda: transformers.GPTJForSequenceClassification(config_for_token_classification),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
import torch
import transformers
from transformers import MistralConfig
from ..registry import ModelAttribute, model_zoo
# ===============================
# Register single-sentence Mistral
# ===============================
def data_gen():
# Generated from following code snippet
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
# input = 'My favourite condiment is vinegar' (last two words repeated to satisfy length requirement)
# tokenized_input = tokenizer([input], return_tensors="pt")
# input_ids = tokenized_input['input_ids']
# attention_mask = tokenized_input['attention_mask']
input_ids = torch.tensor([[1, 1984, 16020, 2076, 2487, 349, 21375, 4749]], dtype=torch.int64)
attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int64)
return dict(input_ids=input_ids, attention_mask=attention_mask)
def data_gen_for_lm():
# LM data gen
    # for LM, `labels` are the target tokens; since there is no padding, `input_ids` are reused as `labels`
data = data_gen()
data["labels"] = data["input_ids"].clone()
return data
def data_gen_for_sequence_classification():
# sequence classification data gen
data = data_gen()
data["labels"] = torch.tensor([1], dtype=torch.int64)
return data
# define output transform function
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_mistral_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
)
loss_fn = lambda x: x.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
config = MistralConfig(
hidden_size=256, intermediate_size=256, num_attention_heads=64, num_hidden_layers=2, vocab_size=50258
)
model_zoo.register(
name="transformers_mistral",
model_fn=lambda: transformers.MistralModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_mistral_model,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_casual_lm",
model_fn=lambda: transformers.MistralForCausalLM(config),
data_gen_fn=data_gen_for_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_mistral_for_sequence_classification",
model_fn=lambda: transformers.MistralForSequenceClassification(config),
data_gen_fn=data_gen_for_sequence_classification,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True),
)
@@ -105,6 +105,11 @@ def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool
             "transformers_sam",
             "transformers_vit",
             "transformers_gpt_double_heads",  # TODO check why does the model fail to run using Gemini
+            "transformers_falcon",  # TODO check why falcon fails to run Gemini
+            "transformers_falcon_for_causal_lm",
+            "transformers_falcon_for_sequence_classification",
+            "transformers_falcon_for_token_classification",
+            "transformers_falcon_for_question_answering",
         ]:
             continue
...
@@ -4,6 +4,7 @@ from types import MethodType
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn

 import colossalai
@@ -11,31 +12,21 @@ from colossalai.cluster import ProcessGroupMesh
 from colossalai.interface import OptimizerWrapper
 from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all

+NUM_LAYER = 8
+DIM = 4
+

 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 8)
-        self.linear3 = nn.Linear(8, 8)
-        self.linear4 = nn.Linear(8, 8)
-        self.linear5 = nn.Linear(8, 8)
-        self.linear6 = nn.Linear(8, 8)
-        self.linear7 = nn.Linear(8, 8)
-        self.linear8 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])

     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
-        x = self.linear3(x)
-        x = self.linear4(x)
-        x = self.linear5(x)
-        x = self.linear6(x)
-        x = self.linear7(x)
-        x = self.linear8(x)
+        for layer in self.layers:
+            x = layer(x)
         return x
@@ -44,70 +35,71 @@ def pp_linear_fwd(
     data: torch.Tensor = None,
     input_obj: torch.Tensor = None,
     stage_mgr: PipelineStageManager = None,
-    num_chunks: int = None,
     model_chunk_id: int = None,
 ):
-    if stage_mgr.is_first_stage() and model_chunk_id == 0:
+    if stage_mgr.is_first_stage(model_chunk_id):
         return {"input_obj": forward(data)}
-    elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
+    elif stage_mgr.is_last_stage(model_chunk_id):
         return forward(input_obj)
     else:
         return {"input_obj": forward(input_obj)}


-@parameterize("num_micro_batches", [4, 8, 12])
-def examine_pp(num_micro_batches):
+def run_pp(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+    num_model_chunk: int,
+):
     """
     This test is to examine the correctness of interleaved 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
-    seed_all(1453)
-
-    NUM_MICRO_BATCHS = num_micro_batches
-    BATCH_SIZE = num_micro_batches
-    NUM_CHUNKS = 2
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")

     # create model
+    seed_all(1453)
     torch_model = MlpModel().cuda()
     pp_model = copy.deepcopy(torch_model).cuda()

-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
-    schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(
+        pg_mesh, pipeline_axis=0, enable_interleave=True, num_model_chunks=num_model_chunk
+    )
+    schedule = InterleavedSchedule(
+        stage_manager=stage_manager,
+        num_model_chunks=num_model_chunk,
+        num_microbatch=num_microbatch,
+    )

     sharded_model = torch.nn.ModuleList()
-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx % world_size == rank:
             sub_model._forward = sub_model.forward
             sub_model.forward = MethodType(
-                partial(
-                    pp_linear_fwd, stage_mgr=stage_manager, num_chunks=NUM_CHUNKS, model_chunk_id=len(sharded_model)
-                ),
+                partial(pp_linear_fwd, stage_mgr=stage_manager, model_chunk_id=len(sharded_model)),
                 sub_model._forward,
             )
             sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_model_chunk, "num_model_chunk is not correct"

     # create optimizer
-    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
-    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
+    torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1e-5)
+    pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1e-5))

-    # create
-    seed_all(1453)
-    if local_rank == 0:
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    # create data
+    seed_all(115)
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])

-    criterion = lambda x, y: torch.mean(x)
+    def criterion(x, *args, **kwargs):
+        return (x * x).mean()

     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()

     pp_ret = schedule.forward_backward_step(
@@ -115,45 +107,41 @@ def examine_pp(num_micro_batches):
     )

     # check loss
-    if stage_manager.is_last_stage():
+    if stage_manager.is_last_stage(-1):
         assert torch.allclose(torch_loss, pp_ret["loss"])

     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
-        else:
-            assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)

     # step
     torch_optimizer.step()
     pp_optimizer.step()

     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        if idx < 2:
-            assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
-        else:
-            assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
-
-
-def run_dist(rank, world_size, port):
-    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    for i in range(num_model_chunk):
+        idx = world_size * i + rank
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)


 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 12])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("num_model_chunk", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 4)
+def test_pp(num_microbatch: int, batch_size: int, num_model_chunk: int):
+    assert NUM_LAYER % num_model_chunk == 0
+    spawn(
+        run_pp,
+        nprocs=NUM_LAYER // num_model_chunk,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+        num_model_chunk=num_model_chunk,
+    )


 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, num_model_chunk=4)
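As a concrete reading of the layer-to-chunk assignment in the test above (values follow its own parametrization): with NUM_LAYER = 8 and num_model_chunk = 2, spawn launches 8 // 2 = 4 processes, and rank r owns chunk i as layer world_size * i + r, i.e. layers r and r + 4.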
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
falcon = unwrap_model(org_model, "FalconModel", "transformer")
sharded_falcon = unwrap_model(sharded_model, "FalconModel", "transformer")
row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
col_layer_for_check = ["h[0].self_attention.dense"]
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
falcon, sharded_falcon, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
col_layer_grads = get_grad_tensors_for_check(
falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "FalconModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_falcon_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
def run_falcon_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_falcon(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_test()
def check_falcon_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_falcon():
spawn(check_falcon, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_falcon_3d():
spawn(check_falcon_3d, 8)
if __name__ == "__main__":
test_falcon()
test_falcon_3d()
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
gptj = unwrap_model(org_model, "GPTJModel", "transformer")
sharded_gptj = unwrap_model(sharded_model, "GPTJModel", "transformer")
col_layer_for_check = ["h[0].attn.k_proj"]
row_layer_for_check = ["h[0].mlp.fc_out"] # use dim=0 for wte get_grad_tensors_for_check
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
row_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "GPTJModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
            # 'use_lazy_init': True,  # GPTJ does not support lazy init yet; training has issues even without sharding
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
#'use_lazy_init': True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
#'use_lazy_init': True,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
#'use_lazy_init': True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
#'use_lazy_init': True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
@clear_cache_before_run()
def run_gptj_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
@clear_cache_before_run()
def run_gptj_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_gptj(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_gptj_test()
def check_gptj_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_gptj_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptj():
spawn(check_gptj, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptj_3d():
spawn(check_gptj_3d, 8)
if __name__ == "__main__":
test_gptj()
test_gptj_3d()
import os
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
mistral_model = unwrap_model(org_model, "MistralModel", "model")
shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 5e-5, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
mistral_model,
shard_mistral_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
mistral_model,
shard_mistral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "MistralModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
mistral_model,
shard_mistral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_mistral_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_mistral(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mistral_test()
@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mistral():
spawn(check_mistral, 4)
if __name__ == "__main__":
test_mistral()