Unverified Commit 27061426 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[gemini] improve compatibility and add static placement policy (#4479)

* [gemini] remove distributed-related part from colotensor (#4379)

* [gemini] remove process group dependency

* [gemini] remove tp part from colo tensor

* [gemini] patch inplace op

* [gemini] fix param op hook and update tests

* [test] remove useless tests

* [test] remove useless tests

* [misc] fix requirements

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [test] fix model zoo

* [misc] update requirements

* [gemini] refactor gemini optimizer and gemini ddp (#4398)

* [gemini] update optimizer interface

* [gemini] renaming gemini optimizer

* [gemini] refactor gemini ddp class

* [example] update gemini related example

* [example] update gemini related example

* [plugin] fix gemini plugin args

* [test] update gemini ckpt tests

* [gemini] fix checkpoint io

* [example] fix opt example requirements

* [example] fix opt example

* [example] fix opt example

* [example] fix opt example

* [gemini] add static placement policy (#4443)

* [gemini] add static placement policy

* [gemini] fix param offload

* [test] update gemini tests

* [plugin] update gemini plugin

* [plugin] update gemini plugin docstr

* [misc] fix flash attn requirement

* [test] fix gemini checkpoint io test

* [example] update resnet example result (#4457)

* [example] update bert example result (#4458)

* [doc] update gemini doc (#4468)

* [example] update gemini related examples (#4473)

* [example] update gpt example

* [example] update dreambooth example

* [example] update vit

* [example] update opt

* [example] update palm

* [example] update vit and opt benchmark

* [hotfix] fix bert in model zoo (#4480)

* [hotfix] fix bert in model zoo

* [test] remove chatglm gemini test

* [test] remove sam gemini test

* [test] remove vit gemini test

* [hotfix] fix opt tutorial example (#4497)

* [hotfix] fix opt tutorial example

* [hotfix] fix opt tutorial example
parent 285fe7ba
...@@ -2,9 +2,9 @@ import argparse ...@@ -2,9 +2,9 @@ import argparse
import hashlib import hashlib
import math import math
import os import os
import shutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import shutil
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -19,6 +19,8 @@ from tqdm.auto import tqdm ...@@ -19,6 +19,8 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig from transformers import AutoTokenizer, PretrainedConfig
import colossalai import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
...@@ -26,8 +28,6 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -26,8 +28,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext from colossalai.zero import ColoInitContext
from colossalai.zero.gemini import get_static_torch_model from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()
...@@ -138,10 +138,10 @@ def parse_args(input_args=None): ...@@ -138,10 +138,10 @@ def parse_args(input_args=None):
" resolution"), " resolution"),
) )
parser.add_argument( parser.add_argument(
"--placement", "--offload_optim_frac",
type=str, type=float,
default="cpu", default=1.0,
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.", help="Fraction of optimizer states to be offloaded. Valid when using colossalai as dist plan.",
) )
parser.add_argument( parser.add_argument(
"--center_crop", "--center_crop",
...@@ -461,18 +461,17 @@ def main(args): ...@@ -461,18 +461,17 @@ def main(args):
revision=args.revision, revision=args.revision,
) )
if args.externel_unet_path is None: if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet", subfolder="unet",
revision=args.revision, revision=args.revision,
low_cpu_mem_usage=False) low_cpu_mem_usage=False)
else: else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision, revision=args.revision,
low_cpu_mem_usage=False) low_cpu_mem_usage=False)
vae.requires_grad_(False) vae.requires_grad_(False)
text_encoder.requires_grad_(False) text_encoder.requires_grad_(False)
...@@ -491,30 +490,31 @@ def main(args): ...@@ -491,30 +490,31 @@ def main(args):
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero # config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
# load noise_scheduler # load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# prepare dataset # prepare dataset
logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0]) logger.info(f"Prepare dataset from {args.instance_data_dir}", ranks=[0])
train_dataset = DreamBoothDataset( train_dataset = DreamBoothDataset(instance_data_root=args.instance_data_dir,
instance_data_root=args.instance_data_dir, instance_prompt=args.instance_prompt,
instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt,
class_prompt=args.class_prompt, tokenizer=tokenizer,
tokenizer=tokenizer, size=args.resolution,
size=args.resolution, center_crop=args.center_crop,
center_crop=args.center_crop, test=args.test_run)
test=args.test_run
)
def collate_fn(examples): def collate_fn(examples):
input_ids = [example["instance_prompt_ids"] for example in examples] input_ids = [example["instance_prompt_ids"] for example in examples]
...@@ -690,6 +690,7 @@ def main(args): ...@@ -690,6 +690,7 @@ def main(args):
if args.push_to_hub: if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -2,9 +2,9 @@ import argparse ...@@ -2,9 +2,9 @@ import argparse
import hashlib import hashlib
import math import math
import os import os
import shutil
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import shutil
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -21,6 +21,8 @@ from tqdm.auto import tqdm ...@@ -21,6 +21,8 @@ from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig from transformers import AutoTokenizer, PretrainedConfig
import colossalai import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
...@@ -28,8 +30,6 @@ from colossalai.nn.optimizer import HybridAdam ...@@ -28,8 +30,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer from colossalai.zero import ColoInitContext, GeminiAdamOptimizer
from colossalai.zero.gemini import get_static_torch_model from colossalai.zero.gemini import get_static_torch_model
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
disable_existing_loggers() disable_existing_loggers()
logger = get_dist_logger() logger = get_dist_logger()
...@@ -459,18 +459,17 @@ def main(args): ...@@ -459,18 +459,17 @@ def main(args):
revision=args.revision, revision=args.revision,
) )
if args.externel_unet_path is None: if args.externel_unet_path is None:
logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0]) logger.info(f"Loading UNet2DConditionModel from {args.pretrained_model_name_or_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet", subfolder="unet",
revision=args.revision, revision=args.revision,
low_cpu_mem_usage=False) low_cpu_mem_usage=False)
else: else:
logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0]) logger.info(f"Loading UNet2DConditionModel from {args.externel_unet_path}", ranks=[0])
unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path, unet = UNet2DConditionModel.from_pretrained(args.externel_unet_path,
revision=args.revision, revision=args.revision,
low_cpu_mem_usage=False) low_cpu_mem_usage=False)
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path,
subfolder="unet", subfolder="unet",
revision=args.revision, revision=args.revision,
...@@ -490,8 +489,7 @@ def main(args): ...@@ -490,8 +489,7 @@ def main(args):
block_id = int(name[len("down_blocks.")]) block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id] hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, lora_attn_procs[name] = LoRACrossAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
cross_attention_dim=cross_attention_dim)
unet.set_attn_processor(lora_attn_procs) unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors) lora_layers = AttnProcsLayers(unet.attn_processors)
...@@ -513,14 +511,17 @@ def main(args): ...@@ -513,14 +511,17 @@ def main(args):
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2 ** 5) plugin = GeminiPlugin(strict_ddp_mode=True, initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
# config optimizer for colossalai zero # config optimizer for colossalai zero
optimizer = HybridAdam(unet.parameters(), lr=args.learning_rate, initial_scale=2**5, clipping_norm=args.max_grad_norm) optimizer = HybridAdam(unet.parameters(),
lr=args.learning_rate,
initial_scale=2**5,
clipping_norm=args.max_grad_norm)
# load noise_scheduler # load noise_scheduler
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
...@@ -711,6 +712,7 @@ def main(args): ...@@ -711,6 +712,7 @@ def main(args):
if args.push_to_hub: if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
if __name__ == "__main__": if __name__ == "__main__":
args = parse_args() args = parse_args()
main(args) main(args)
...@@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80 ...@@ -49,8 +49,8 @@ python eval.py -c ./ckpt-low_level_zero -e 80
Expected accuracy performance will be: Expected accuracy performance will be:
| Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | | Model | Single-GPU Baseline FP32 | Booster DDP with FP32 | Booster DDP with FP16 | Booster Low Level Zero | Booster Gemini |
| --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | | --------- | ------------------------ | --------------------- | --------------------- | ---------------------- | -------------- |
| ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | | ResNet-18 | 85.85% | 84.91% | 85.46% | 84.50% | 84.60% |
**Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`** **Note: the baseline is adapted from the [script](https://pytorch-tutorial.readthedocs.io/en/latest/tutorial/chapter03_intermediate/3_2_2_cnn_resnet_cifar10/) to use `torchvision.models.resnet18`**
...@@ -104,7 +104,7 @@ def main(): ...@@ -104,7 +104,7 @@ def main():
'--plugin', '--plugin',
type=str, type=str,
default='torch_ddp', default='torch_ddp',
choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero'], choices=['torch_ddp', 'torch_ddp_fp16', 'low_level_zero', 'gemini'],
help="plugin to use") help="plugin to use")
parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint") parser.add_argument('-r', '--resume', type=int, default=-1, help="resume from the epoch's checkpoint")
parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory") parser.add_argument('-c', '--checkpoint', type=str, default='./checkpoint', help="checkpoint directory")
...@@ -141,7 +141,7 @@ def main(): ...@@ -141,7 +141,7 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
......
import time import time
import torch import torch
import tqdm
import transformers import transformers
from args import parse_benchmark_args
from transformers import ViTConfig, ViTForImageClassification from transformers import ViTConfig, ViTForImageClassification
import tqdm
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import get_current_device
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from args import parse_benchmark_args
def format_num(num: int, bytes=False): def format_num(num: int, bytes=False):
"""Scale bytes to its proper format, e.g. 1253656 => '1.20MB'""" """Scale bytes to its proper format, e.g. 1253656 => '1.20MB'"""
...@@ -26,8 +25,13 @@ def format_num(num: int, bytes=False): ...@@ -26,8 +25,13 @@ def format_num(num: int, bytes=False):
def get_data(batch_size, num_labels, num_channels=3, height=224, width=224): def get_data(batch_size, num_labels, num_channels=3, height=224, width=224):
pixel_values = torch.randn(batch_size, num_channels, height, width, device=torch.cuda.current_device(), dtype=torch.float) pixel_values = torch.randn(batch_size,
labels = torch.randint(0, num_labels, (batch_size, ), device=torch.cuda.current_device(), dtype=torch.int64) num_channels,
height,
width,
device=torch.cuda.current_device(),
dtype=torch.float)
labels = torch.randint(0, num_labels, (batch_size,), device=torch.cuda.current_device(), dtype=torch.int64)
return pixel_values, labels return pixel_values, labels
...@@ -55,11 +59,11 @@ def main(): ...@@ -55,11 +59,11 @@ def main():
transformers.utils.logging.set_verbosity_info() transformers.utils.logging.set_verbosity_info()
else: else:
transformers.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error()
# Whether to set limit on memory capacity # Whether to set limit on memory capacity
if args.mem_cap > 0: if args.mem_cap > 0:
colo_memory_cap(args.mem_cap) colo_memory_cap(args.mem_cap)
# Build ViT model # Build ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path) config = ViTConfig.from_pretrained(args.model_name_or_path)
model = ViTForImageClassification(config) model = ViTForImageClassification(config)
...@@ -75,11 +79,7 @@ def main(): ...@@ -75,11 +79,7 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(), plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) logger.info(f"Set plugin as {args.plugin}", ranks=[0])
...@@ -90,16 +90,15 @@ def main(): ...@@ -90,16 +90,15 @@ def main():
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer) model, optimizer, _, _, _ = booster.boost(model, optimizer)
# Start training. # Start training.
logger.info(f"Start testing", ranks=[0]) logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize() torch.cuda.synchronize()
model.train() model.train()
start_time = time.time() start_time = time.time()
for _ in range(args.max_train_steps): for _ in range(args.max_train_steps):
pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224) pixel_values, labels = get_data(args.batch_size, args.num_labels, 3, 224, 224)
...@@ -111,18 +110,19 @@ def main(): ...@@ -111,18 +110,19 @@ def main():
torch.cuda.synchronize() torch.cuda.synchronize()
progress_bar.update(1) progress_bar.update(1)
# Compute Statistics # Compute Statistics
end_time = time.time() end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
logger.info(f"Testing finished, " logger.info(
f"batch size per gpu: {args.batch_size}, " f"Testing finished, "
f"plugin: {args.plugin}, " f"batch size per gpu: {args.batch_size}, "
f"throughput: {throughput}, " f"plugin: {args.plugin}, "
f"maximum memory usage per gpu: {max_mem}.", f"throughput: {throughput}, "
ranks=[0]) f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
if __name__ == "__main__": if __name__ == "__main__":
......
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers import transformers
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor from args import parse_demo_args
from data import BeansDataset, beans_collator
from tqdm import tqdm from tqdm import tqdm
from transformers import ViTConfig, ViTForImageClassification, ViTImageProcessor
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.utils import get_current_device
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from args import parse_demo_args from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from data import BeansDataset, beans_collator from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
def move_to_cuda(batch, device): def move_to_cuda(batch, device):
...@@ -22,12 +21,12 @@ def move_to_cuda(batch, device): ...@@ -22,12 +21,12 @@ def move_to_cuda(batch, device):
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
torch.cuda.synchronize() torch.cuda.synchronize()
model.train() model.train()
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar: for batch in pbar:
# Forward # Forward
...@@ -47,7 +46,7 @@ def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coor ...@@ -47,7 +46,7 @@ def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coor
@torch.no_grad() @torch.no_grad()
def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
model.eval() model.eval()
accum_loss = torch.zeros(1, device=get_current_device()) accum_loss = torch.zeros(1, device=get_current_device())
total_num = torch.zeros(1, device=get_current_device()) total_num = torch.zeros(1, device=get_current_device())
...@@ -76,9 +75,7 @@ def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator): ...@@ -76,9 +75,7 @@ def evaluate_model(epoch, model, eval_dataloader, num_labels, coordinator):
print(f"Evaluation result for epoch {epoch + 1}: \ print(f"Evaluation result for epoch {epoch + 1}: \
average_loss={avg_loss}, \ average_loss={avg_loss}, \
accuracy={accuracy}.") accuracy={accuracy}.")
def main(): def main():
...@@ -102,14 +99,13 @@ def main(): ...@@ -102,14 +99,13 @@ def main():
train_dataset = BeansDataset(image_processor, split='train') train_dataset = BeansDataset(image_processor, split='train')
eval_dataset = BeansDataset(image_processor, split='validation') eval_dataset = BeansDataset(image_processor, split='validation')
# Load pretrained ViT model # Load pretrained ViT model
config = ViTConfig.from_pretrained(args.model_name_or_path) config = ViTConfig.from_pretrained(args.model_name_or_path)
config.num_labels = train_dataset.num_labels config.num_labels = train_dataset.num_labels
config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)} config.id2label = {str(i): c for i, c in enumerate(train_dataset.label_names)}
config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)} config.label2id = {c: str(i) for i, c in enumerate(train_dataset.label_names)}
model = ViTForImageClassification.from_pretrained(args.model_name_or_path, model = ViTForImageClassification.from_pretrained(args.model_name_or_path,
config=config, config=config,
ignore_mismatched_sizes=True) ignore_mismatched_sizes=True)
logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0]) logger.info(f"Finish loading model from {args.model_name_or_path}", ranks=[0])
...@@ -123,26 +119,22 @@ def main(): ...@@ -123,26 +119,22 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(), plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare dataloader # Prepare dataloader
train_dataloader = plugin.prepare_dataloader(train_dataset, train_dataloader = plugin.prepare_dataloader(train_dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
collate_fn=beans_collator) collate_fn=beans_collator)
eval_dataloader = plugin.prepare_dataloader(eval_dataset, eval_dataloader = plugin.prepare_dataloader(eval_dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
collate_fn=beans_collator) collate_fn=beans_collator)
# Set optimizer # Set optimizer
optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay) optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
...@@ -156,11 +148,11 @@ def main(): ...@@ -156,11 +148,11 @@ def main():
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model, model, optimizer, _, train_dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer, optimizer=optimizer,
dataloader=train_dataloader, dataloader=train_dataloader,
lr_scheduler=lr_scheduler) lr_scheduler=lr_scheduler)
# Finetuning # Finetuning
logger.info(f"Start finetuning", ranks=[0]) logger.info(f"Start finetuning", ranks=[0])
for epoch in range(args.num_epoch): for epoch in range(args.num_epoch):
...@@ -174,4 +166,4 @@ def main(): ...@@ -174,4 +166,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
\ No newline at end of file
...@@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be ...@@ -7,6 +7,14 @@ This directory includes two parts: Using the Booster API finetune Huggingface Be
bash test_ci.sh bash test_ci.sh
``` ```
### Results on 2-GPU
| Plugin | Accuracy | F1-score |
| -------------- | -------- | -------- |
| torch_ddp | 84.4% | 88.6% |
| torch_ddp_fp16 | 84.7% | 88.8% |
| gemini | 84.0% | 88.4% |
## Benchmark ## Benchmark
``` ```
bash benchmark.sh bash benchmark.sh
...@@ -14,9 +22,9 @@ bash benchmark.sh ...@@ -14,9 +22,9 @@ bash benchmark.sh
The benchmark currently reports these metrics: CUDA memory usage, throughput, and the number of model parameters. If you have custom metrics, you can add them to benchmark_util. The benchmark currently reports these metrics: CUDA memory usage, throughput, and the number of model parameters. If you have custom metrics, you can add them to benchmark_util.
## Results ### Results
### Bert #### Bert
| | max cuda mem | throughput(sample/s) | params | | | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: | | :-----| -----------: | :--------: | :----: |
...@@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb ...@@ -25,10 +33,10 @@ Now include these metrics in benchmark: CUDA mem occupy, throughput and the numb
| gemini | 11.0 GB | 12.9 | 82M | | gemini | 11.0 GB | 12.9 | 82M |
| low_level_zero | 11.29 G | 14.7 | 82M | | low_level_zero | 11.29 G | 14.7 | 82M |
### AlBert #### AlBert
| | max cuda mem | throughput(sample/s) | params | | | max cuda mem | throughput(sample/s) | params |
| :-----| -----------: | :--------: | :----: | | :-----| -----------: | :--------: | :----: |
| ddp | OOM | | | | ddp | OOM | | |
| ddp_fp16 | OOM | | | | ddp_fp16 | OOM | | |
| gemini | 69.39 G | 1.3 | 208M | | gemini | 69.39 G | 1.3 | 208M |
| low_level_zero | 56.89 G | 1.4 | 208M | | low_level_zero | 56.89 G | 1.4 | 208M |
\ No newline at end of file
...@@ -38,8 +38,8 @@ def move_to_cuda(batch): ...@@ -38,8 +38,8 @@ def move_to_cuda(batch):
@torch.no_grad() @torch.no_grad()
def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int, task_name: str, def evaluate_model(model: nn.Module, test_dataloader: Union[DataLoader, List[DataLoader]], num_labels: int,
eval_splits: List[str], coordinator: DistCoordinator): task_name: str, eval_splits: List[str], coordinator: DistCoordinator):
metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size) metric = evaluate.load("glue", task_name, process_id=coordinator.rank, num_process=coordinator.world_size)
model.eval() model.eval()
...@@ -142,7 +142,7 @@ def main(): ...@@ -142,7 +142,7 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy='cuda', strict_ddp_mode=True, initial_scale=2**5) plugin = GeminiPlugin(initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
...@@ -208,7 +208,7 @@ def main(): ...@@ -208,7 +208,7 @@ def main():
train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator) train_epoch(epoch, model, optimizer, lr_scheduler, train_dataloader, booster, coordinator)
results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits, results = evaluate_model(model, test_dataloader, data_builder.num_labels, args.task, data_builder.eval_splits,
coordinator) coordinator)
if coordinator.is_master(): if coordinator.is_master():
print(results) print(results)
......
...@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"} ...@@ -4,9 +4,6 @@ export DISTPLAN=${DISTPLAN:-"CAI_Gemini"}
# The following options are only valid when DISTPLAN="colossalai" # The following options are only valid when DISTPLAN="colossalai"
export GPUNUM=${GPUNUM:-1} export GPUNUM=${GPUNUM:-1}
export TPDEGREE=${TPDEGREE:-1}
export PLACEMENT=${PLACEMENT:-"cpu"}
export USE_SHARD_INIT=${USE_SHARD_INIT:-False}
export BATCH_SIZE=${BATCH_SIZE:-16} export BATCH_SIZE=${BATCH_SIZE:-16}
export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"} export MODEL_TYPE=${MODEL_TYPE:-"gpt2_medium"}
export TRAIN_STEP=${TRAIN_STEP:-10} export TRAIN_STEP=${TRAIN_STEP:-10}
...@@ -21,11 +18,8 @@ fi ...@@ -21,11 +18,8 @@ fi
mkdir -p gemini_logs mkdir -p gemini_logs
torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \ torchrun --standalone --nproc_per_node=${GPUNUM} ./train_gpt_demo.py \
--tp_degree=${TPDEGREE} \
--model_type=${MODEL_TYPE} \ --model_type=${MODEL_TYPE} \
--batch_size=${BATCH_SIZE} \ --batch_size=${BATCH_SIZE} \
--placement=${PLACEMENT} \
${USE_SHARD_INIT} \
--distplan=${DISTPLAN} \ --distplan=${DISTPLAN} \
--train_step=${TRAIN_STEP} \ --train_step=${TRAIN_STEP} \
2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log 2>&1 | tee ./gemini_logs/${MODEL_TYPE}_${DISTPLAN}_gpu_${GPUNUM}_bs_${BATCH_SIZE}_tp_${TPDEGREE}_${PLACEMENT}.log
...@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do ...@@ -6,29 +6,17 @@ for MODEL_TYPE in "gpt2_medium"; do
for DISTPLAN in "CAI_Gemini"; do for DISTPLAN in "CAI_Gemini"; do
for BATCH_SIZE in 2; do for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do for GPUNUM in 1 4; do
for TPDEGREE in 1 2; do MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then bash ./run_gemini.sh
continue
fi
for PLACEMENT in "cpu" "auto"; do
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE} PLACEMENT=${PLACEMENT} \
bash ./run_gemini.sh
done
done
done done
done done
done done
for DISTPLAN in "zero1" "zero2"; do for DISTPLAN in "CAI_ZeRO2" "CAI_ZeRO1"; do
for BATCH_SIZE in 2; do for BATCH_SIZE in 2; do
for GPUNUM in 1 4; do for GPUNUM in 1 4; do
for TPDEGREE in 1; do MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} \
if [ ${TPDEGREE} -gt ${GPUNUM} ]; then bash ./run_gemini.sh
continue
fi
MODEL_TYPE=${MODEL_TYPE} DISTPLAN=${DISTPLAN} BATCH_SIZE=${BATCH_SIZE} GPUNUM=${GPUNUM} TPDEGREE=${TPDEGREE}\
bash ./run_gemini.sh
done
done done
done done
done done
......
import os import os
from contextlib import nullcontext
from functools import partial from functools import partial
from time import time from time import time
...@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP ...@@ -13,11 +14,10 @@ from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai import colossalai
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
CAI_VERSION = colossalai.__version__ CAI_VERSION = colossalai.__version__
...@@ -30,24 +30,6 @@ def parse_args(): ...@@ -30,24 +30,6 @@ def parse_args():
default='CAI_Gemini', default='CAI_Gemini',
help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].", help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
) )
parser.add_argument(
"--tp_degree",
type=int,
default=1,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
action='store_true',
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
)
parser.add_argument( parser.add_argument(
"--batch_size", "--batch_size",
type=int, type=int,
...@@ -71,20 +53,6 @@ def parse_args(): ...@@ -71,20 +53,6 @@ def parse_args():
return args return args
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
class GPTLMLoss(nn.Module): class GPTLMLoss(nn.Module):
def __init__(self): def __init__(self):
...@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism(): ...@@ -140,47 +108,6 @@ def set_cpu_maximum_parallelism():
print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.") print(f"environmental variable OMP_NUM_THREADS is set to {max_concurrency}.")
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
# NOTE() a param maybe shared by two modules
if hasattr(param, 'visited'):
continue
# if shard init, then convert param to replica and use the dp-only ProcessGroup
param: ColoParameter = param
param.set_dist_spec(ReplicaSpec())
param.set_process_group(pg)
# shard it w.r.t tp pattern
if 'mlp.c_fc' in mn:
if 'weight' in pn or 'bias' in pn:
split_param_col_tp1d(param, pg) # column slice
# keep the shape of the output from c_fc
param.compute_spec.set_output_replicate(False)
else:
param.set_dist_spec(ReplicaSpec())
elif 'mlp.c_proj' in mn:
if 'weight' in pn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
elif 'wte' in mn or 'wpe' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'c_attn' in mn or 'c_proj' in mn:
split_param_col_tp1d(param, pg) # column slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
def main(): def main():
# version check # version check
# this example is supposed to work for versions greater than 0.2.0 # this example is supposed to work for versions greater than 0.2.0
...@@ -213,30 +140,13 @@ def main(): ...@@ -213,30 +140,13 @@ def main():
# build criterion # build criterion
criterion = GPTLMLoss() criterion = GPTLMLoss()
torch.manual_seed(123) torch.manual_seed(123)
if args.distplan.startswith("CAI"): if args.distplan.startswith("CAI"):
# all param must use the same process group. ctx = LazyInitContext(default_device=get_current_device()) if args.distplan == "CAI_Gemini" else nullcontext()
world_size = torch.distributed.get_world_size()
shard_pg = ProcessGroup(tp_degree=world_size) if args.shardinit else None
default_dist_spec = ShardSpec([-1], [world_size]) if args.shardinit else None
if args.shardinit and args.distplan != "CAI_Gemini":
raise RuntimeError("You can only use shardinit with CAI_Gemini")
# build GPT model # build GPT model
with ColoInitContext(device=get_current_device(), with ctx:
dtype=torch.half,
default_dist_spec=default_dist_spec,
default_pg=shard_pg):
model = model_builder(args.model_type)(checkpoint=True) model = model_builder(args.model_type)(checkpoint=True)
tp_pg = ProcessGroup(tp_degree=args.tp_degree)
# Tensor Parallelism (TP)
# You should notice that v0.1.10 is not compatible with TP degree > 1
if args.tp_degree > 1:
tensor_parallelize(model, tp_pg)
# assign running configurations # assign running configurations
if args.distplan == "CAI_ZeRO1": if args.distplan == "CAI_ZeRO1":
zero_stage = 1 zero_stage = 1
...@@ -254,13 +164,7 @@ def main(): ...@@ -254,13 +164,7 @@ def main():
overlap_communication=True, overlap_communication=True,
verbose=True) verbose=True)
elif args.distplan == "CAI_Gemini": elif args.distplan == "CAI_Gemini":
plugin = GeminiPlugin(device=get_current_device(), plugin = GeminiPlugin(search_range_m=128, hidden_dim=model.config.n_embd)
placement_policy=args.placement,
pin_memory=True,
strict_ddp_mode=args.tp_degree == 1,
search_range_m=128,
hidden_dim=model.config.n_embd,
gpu_margin_mem_ratio=0.)
else: else:
raise RuntimeError raise RuntimeError
......
import time import time
import torch import torch
import tqdm
import transformers import transformers
from args import parse_benchmark_args
from transformers import AutoConfig, OPTForCausalLM from transformers import AutoConfig, OPTForCausalLM
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
import tqdm
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from args import parse_benchmark_args from colossalai.nn.optimizer import HybridAdam
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
...@@ -61,11 +57,11 @@ def main(): ...@@ -61,11 +57,11 @@ def main():
transformers.utils.logging.set_verbosity_info() transformers.utils.logging.set_verbosity_info()
else: else:
transformers.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error()
# Whether to set limit of memory capacity # Whether to set limit of memory capacity
if args.mem_cap > 0: if args.mem_cap > 0:
colo_memory_cap(args.mem_cap) colo_memory_cap(args.mem_cap)
# Build OPT model # Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path) config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM(config=config) model = OPTForCausalLM(config=config)
...@@ -81,11 +77,7 @@ def main(): ...@@ -81,11 +77,7 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(), plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) logger.info(f"Set plugin as {args.plugin}", ranks=[0])
...@@ -96,18 +88,18 @@ def main(): ...@@ -96,18 +88,18 @@ def main():
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, _, _ = booster.boost(model, optimizer) model, optimizer, _, _, _ = booster.boost(model, optimizer)
SEQ_LEN = 1024 SEQ_LEN = 1024
VOCAB_SIZE = 50257 VOCAB_SIZE = 50257
# Start training. # Start training.
logger.info(f"Start testing", ranks=[0]) logger.info(f"Start testing", ranks=[0])
progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master()) progress_bar = tqdm.tqdm(total=args.max_train_steps, desc="Training Step", disable=not coordinator.is_master())
torch.cuda.synchronize() torch.cuda.synchronize()
model.train() model.train()
start_time = time.time() start_time = time.time()
for _ in range(args.max_train_steps): for _ in range(args.max_train_steps):
input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE) input_ids, attn_mask = get_data(args.batch_size, SEQ_LEN, VOCAB_SIZE)
...@@ -119,18 +111,19 @@ def main(): ...@@ -119,18 +111,19 @@ def main():
torch.cuda.synchronize() torch.cuda.synchronize()
progress_bar.update(1) progress_bar.update(1)
# Compute Statistics # Compute Statistics
end_time = time.time() end_time = time.time()
throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time)) throughput = "{:.4f}".format((world_size * args.max_train_steps * args.batch_size) / (end_time - start_time))
max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True) max_mem = format_num(torch.cuda.max_memory_allocated(device=torch.cuda.current_device()), bytes=True)
logger.info(f"Testing finished, " logger.info(
f"batch size per gpu: {args.batch_size}, " f"Testing finished, "
f"plugin: {args.plugin}, " f"batch size per gpu: {args.batch_size}, "
f"throughput: {throughput}, " f"plugin: {args.plugin}, "
f"maximum memory usage per gpu: {max_mem}.", f"throughput: {throughput}, "
ranks=[0]) f"maximum memory usage per gpu: {max_mem}.",
ranks=[0])
if __name__ == "__main__": if __name__ == "__main__":
......
import time import time
import torch
import datasets import datasets
import torch
import transformers import transformers
from transformers import AutoConfig, OPTForCausalLM, AutoTokenizer from args import parse_demo_args
from transformers import get_linear_schedule_with_warmup from data import NetflixDataset, netflix_collator
from transformers.utils.versions import require_version
from tqdm import tqdm from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer, OPTForCausalLM, get_linear_schedule_with_warmup
from transformers.utils.versions import require_version
import colossalai import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.zero import ColoInitContext
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.cluster import DistCoordinator from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers, get_dist_logger
from args import parse_demo_args from colossalai.nn.optimizer import HybridAdam
from data import NetflixDataset, netflix_collator
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt") require_version("transformers>=4.20.0", "To fix: pip install -r requirements.txt")
...@@ -30,18 +25,18 @@ def move_to_cuda(batch, device): ...@@ -30,18 +25,18 @@ def move_to_cuda(batch, device):
def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator): def train_epoch(epoch, model, optimizer, lr_scheduler, dataloader, booster, coordinator):
torch.cuda.synchronize() torch.cuda.synchronize()
model.train() model.train()
with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar: with tqdm(dataloader, desc=f'Epoch [{epoch + 1}]', disable=not coordinator.is_master()) as pbar:
for batch in pbar: for batch in pbar:
# Forward # Forward
optimizer.zero_grad() optimizer.zero_grad()
batch = move_to_cuda(batch, torch.cuda.current_device()) batch = move_to_cuda(batch, torch.cuda.current_device())
outputs = model(use_cache=False, **batch) outputs = model(use_cache=False, **batch)
loss = outputs['loss'] loss = outputs['loss']
...@@ -72,7 +67,7 @@ def main(): ...@@ -72,7 +67,7 @@ def main():
else: else:
datasets.utils.logging.set_verbosity_error() datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error() transformers.utils.logging.set_verbosity_error()
# Build OPT model # Build OPT model
config = AutoConfig.from_pretrained(args.model_name_or_path) config = AutoConfig.from_pretrained(args.model_name_or_path)
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config) model = OPTForCausalLM.from_pretrained(args.model_name_or_path, config=config)
...@@ -88,43 +83,35 @@ def main(): ...@@ -88,43 +83,35 @@ def main():
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(device=get_current_device(), plugin = GeminiPlugin(offload_optim_frac=1.0, pin_memory=True, initial_scale=2**5)
placement_policy='cpu',
pin_memory=True,
strict_ddp_mode=True,
initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2**5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"Set plugin as {args.plugin}", ranks=[0]) logger.info(f"Set plugin as {args.plugin}", ranks=[0])
# Prepare tokenizer and dataloader # Prepare tokenizer and dataloader
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
dataset = NetflixDataset(tokenizer) dataset = NetflixDataset(tokenizer)
dataloader = plugin.prepare_dataloader(dataset, dataloader = plugin.prepare_dataloader(dataset,
batch_size=args.batch_size, batch_size=args.batch_size,
shuffle=True, shuffle=True,
drop_last=True, drop_last=True,
collate_fn=netflix_collator) collate_fn=netflix_collator)
# Set optimizer # Set optimizer
optimizer = HybridAdam(model.parameters(), optimizer = HybridAdam(model.parameters(), lr=(args.learning_rate * world_size), weight_decay=args.weight_decay)
lr=(args.learning_rate * world_size),
weight_decay=args.weight_decay)
# Set lr scheduler # Set lr scheduler
total_steps = len(dataloader) * args.num_epoch total_steps = len(dataloader) * args.num_epoch
num_warmup_steps = int(args.warmup_ratio * total_steps) num_warmup_steps = int(args.warmup_ratio * total_steps)
lr_scheduler = get_linear_schedule_with_warmup( lr_scheduler = get_linear_schedule_with_warmup(optimizer,
optimizer, num_warmup_steps=num_warmup_steps,
num_warmup_steps=num_warmup_steps, num_training_steps=len(dataloader) * args.num_epoch)
num_training_steps=len(dataloader) * args.num_epoch
)
# Set booster # Set booster
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model, model, optimizer, _, dataloader, lr_scheduler = booster.boost(model=model,
optimizer=optimizer, optimizer=optimizer,
dataloader=dataloader, dataloader=dataloader,
lr_scheduler=lr_scheduler) lr_scheduler=lr_scheduler)
# Start finetuning # Start finetuning
......
import gzip import gzip
import random from contextlib import nullcontext
from functools import partial from functools import partial
from time import time from time import time
...@@ -8,20 +8,17 @@ import torch ...@@ -8,20 +8,17 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
import tqdm import tqdm
from packaging import version
from colossalai.nn import HybridAdam
from palm_pytorch import PaLM from palm_pytorch import PaLM
from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper from palm_pytorch.autoregressive_wrapper import AutoregressiveWrapper
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
import colossalai import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
from colossalai.utils import MultiTimer, get_current_device
from colossalai.zero import ColoInitContext, GeminiAdamOptimizer, ZeroDDP
from colossalai.booster import Booster from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin from colossalai.booster.plugin import GeminiPlugin, LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.lazy import LazyInitContext
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import HybridAdam
from colossalai.utils import get_current_device
# constants # constants
...@@ -44,23 +41,10 @@ def parse_args(): ...@@ -44,23 +41,10 @@ def parse_args():
help="The distributed plan [colossalai, pytorch].", help="The distributed plan [colossalai, pytorch].",
) )
parser.add_argument( parser.add_argument(
"--tp_degree", "--offload_optim_frac",
type=int, type=float,
default=1, default=1.0,
help="Tensor Parallelism Degree. Valid when using colossalai as dist plan.", help="Fraction of optimizer states to be offloaded. This is only used for gemini.",
)
parser.add_argument(
"--placement",
type=str,
default='cpu',
help="Placement Policy for Gemini. Valid when using colossalai as dist plan.",
)
parser.add_argument(
"--shardinit",
type=bool,
default=False,
help=
"Shard the tensors when init the model to shrink peak memory size on the assigned device. Valid when using colossalai as dist plan.",
) )
parser.add_argument('-p', parser.add_argument('-p',
'--plugin', '--plugin',
...@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module): ...@@ -111,51 +95,6 @@ def get_model_size(model: nn.Module):
return total_numel return total_numel
# Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
param.set_tensor_spec(*spec)
def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(0, param, pg)
def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
split_param_single_dim_tp1d(-1, param, pg)
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
"""tensor_parallelize
Sharding the Model Parameters.
Args:
model (torch.nn.Module): a torch module to be sharded
"""
for mn, module in model.named_modules():
for pn, param in module.named_parameters(recurse=False):
if hasattr(param, 'visited'):
continue
param.set_dist_spec(ReplicaSpec())
if 'net.0' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_q' in mn:
split_param_col_tp1d(param, pg) # column slice
elif 'to_kv' in mn:
split_param_row_tp1d(param, pg) # row slice
elif 'to_out' in mn:
split_param_row_tp1d(param, pg) # row slice
elif '1.1' in mn:
split_param_col_tp1d(param, pg) # column slice
elif '1.2' in mn:
split_param_row_tp1d(param, pg) # row slice
else:
param.set_dist_spec(ReplicaSpec())
param.visited = True
args = parse_args() args = parse_args()
if args.distplan not in ["colossalai", "pytorch"]: if args.distplan not in ["colossalai", "pytorch"]:
raise TypeError(f"{args.distplan} is error") raise TypeError(f"{args.distplan} is error")
...@@ -212,23 +151,18 @@ if args.distplan == "colossalai": ...@@ -212,23 +151,18 @@ if args.distplan == "colossalai":
if args.plugin.startswith('torch_ddp'): if args.plugin.startswith('torch_ddp'):
plugin = TorchDDPPlugin() plugin = TorchDDPPlugin()
elif args.plugin == 'gemini': elif args.plugin == 'gemini':
plugin = GeminiPlugin(placement_policy=args.placement, strict_ddp_mode=True, initial_scale=2 ** 5) plugin = GeminiPlugin(offload_optim_frac=args.offload_optim_frac, initial_scale=2**5)
elif args.plugin == 'low_level_zero': elif args.plugin == 'low_level_zero':
plugin = LowLevelZeroPlugin(initial_scale=2 ** 5) plugin = LowLevelZeroPlugin(initial_scale=2**5)
logger.info(f"plugin: {plugin}") logger.info(f"plugin: {plugin}")
booster = Booster(plugin=plugin, **booster_kwargs) booster = Booster(plugin=plugin, **booster_kwargs)
default_pg = ProcessGroup(tp_degree=args.tp_degree) ctx = LazyInitContext(default_device=get_current_device()) if args.plugin == 'gemini' else nullcontext()
default_dist_spec = ShardSpec([-1], [args.tp_degree]) if args.shardinit else None
ctx = ColoInitContext(device='cpu', default_dist_spec=default_dist_spec, default_pg=default_pg)
with ctx: with ctx:
model = PaLM(num_tokens=50304, dim=4096, depth=64) model = PaLM(num_tokens=50304, dim=4096, depth=64)
model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN) model = AutoregressiveWrapper(model, max_seq_len=SEQ_LEN)
pg = default_pg
tensor_parallelize(model, pg)
# optimizer # optimizer
optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5) optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE, initial_scale=2**5)
......
...@@ -3,5 +3,5 @@ torch >= 1.8.1 ...@@ -3,5 +3,5 @@ torch >= 1.8.1
datasets >= 1.8.0 datasets >= 1.8.0
sentencepiece != 0.1.92 sentencepiece != 0.1.92
protobuf protobuf
accelerate == 0.13.2 accelerate
transformers transformers
...@@ -30,7 +30,7 @@ from itertools import chain ...@@ -30,7 +30,7 @@ from itertools import chain
import datasets import datasets
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import transformers import transformers.utils.logging as logging
from accelerate.utils import set_seed from accelerate.utils import set_seed
from context import barrier_context from context import barrier_context
from datasets import load_dataset from datasets import load_dataset
...@@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger ...@@ -57,7 +57,7 @@ from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.tensor import ProcessGroup from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader from colossalai.utils import get_current_device, get_dataloader
from colossalai.zero import ColoInitContext, ZeroDDP, ZeroOptimizer from colossalai.zero import GeminiOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
...@@ -292,10 +292,10 @@ def main(): ...@@ -292,10 +292,10 @@ def main():
if is_main_process: if is_main_process:
datasets.utils.logging.set_verbosity_warning() datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_info() logging.set_verbosity_info()
else: else:
datasets.utils.logging.set_verbosity_error() datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error() logging.set_verbosity_error()
if args.mem_cap > 0: if args.mem_cap > 0:
colo_memory_cap(args.mem_cap) colo_memory_cap(args.mem_cap)
...@@ -391,16 +391,28 @@ def main(): ...@@ -391,16 +391,28 @@ def main():
else: else:
init_dev = get_current_device() init_dev = get_current_device()
cai_version = colossalai.__version__
logger.info(f'using Colossal-AI version {cai_version}')
# build model # build model
if version.parse(cai_version) >= version.parse("0.3.1"):
from contextlib import nullcontext
from colossalai.lazy import LazyInitContext
ctx = LazyInitContext(
default_device=init_dev
) if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b' else nullcontext()
else:
from colossalai.zero import ColoInitContext
ctx = ColoInitContext(device=init_dev)
if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b': if args.model_name_or_path is None or args.model_name_or_path == 'facebook/opt-13b':
# currently, there is a bug in pretrained opt-13b # currently, there is a bug in pretrained opt-13b
# we cannot import it until huggingface fixes it # we cannot import it until huggingface fixes it
logger.info("Train a new model from scratch", ranks=[0]) logger.info("Train a new model from scratch", ranks=[0])
with ColoInitContext(device=init_dev): with ctx:
model = OPTForCausalLM(config) model = OPTForCausalLM(config)
else: else:
logger.info("Finetune a pre-trained model", ranks=[0]) logger.info("Finetune a pre-trained model", ranks=[0])
with ColoInitContext(device=init_dev): with ctx:
model = OPTForCausalLM.from_pretrained(args.model_name_or_path, model = OPTForCausalLM.from_pretrained(args.model_name_or_path,
from_tf=bool(".ckpt" in args.model_name_or_path), from_tf=bool(".ckpt" in args.model_name_or_path),
config=config, config=config,
...@@ -410,9 +422,10 @@ def main(): ...@@ -410,9 +422,10 @@ def main():
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
PLACEMENT_POLICY = 'auto' PLACEMENT_POLICY = 'auto'
cai_version = colossalai.__version__ if version.parse(cai_version) >= version.parse("0.3.1"):
logger.info(f'using Colossal-AI version {cai_version}') from colossalai.zero import GeminiDDP
if version.parse(cai_version) > version.parse("0.1.10"): model = GeminiDDP(model, offload_optim_frac=1.0, pin_memory=True)
elif version.parse(cai_version) > version.parse("0.1.10"):
try: try:
from colossalai.nn.parallel import GeminiDDP from colossalai.nn.parallel import GeminiDDP
except ImportError: except ImportError:
...@@ -536,7 +549,6 @@ def main(): ...@@ -536,7 +549,6 @@ def main():
] ]
optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate) optimizer = HybridAdam(optimizer_grouped_parameters, lr=args.learning_rate)
optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**14)
# Scheduler and math around the number of training steps. # Scheduler and math around the number of training steps.
overrode_max_train_steps = False overrode_max_train_steps = False
...@@ -551,6 +563,7 @@ def main(): ...@@ -551,6 +563,7 @@ def main():
num_warmup_steps=args.num_warmup_steps, num_warmup_steps=args.num_warmup_steps,
num_training_steps=args.max_train_steps, num_training_steps=args.max_train_steps,
) )
optimizer = GeminiOptimizer(optimizer, model, initial_scale=2**14)
# We need to recalculate our total training steps as the size of the training dataloader may have changed. # We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
......
...@@ -4,9 +4,9 @@ set -xue ...@@ -4,9 +4,9 @@ set -xue
pip install -r requirements.txt pip install -r requirements.txt
BS=8 BS=4
MEMCAP=0 MEMCAP=0
GPUNUM=2 GPUNUM=4
MODLE="facebook/opt-125m" MODLE="facebook/opt-125m"
torchrun \ torchrun \
......
...@@ -4,4 +4,5 @@ markers = ...@@ -4,4 +4,5 @@ markers =
gpu: tests which require a single GPU gpu: tests which require a single GPU
dist: tests which are run in a multi-GPU or multi-machine environment dist: tests which are run in a multi-GPU or multi-machine environment
experiment: tests for experimental features experiment: tests for experimental features
addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe addopts = --ignore=tests/test_analyzer --ignore=tests/test_auto_parallel --ignore=tests/test_autochunk --ignore=tests/test_moe --ignore=tests/test_fx
...@@ -17,6 +17,13 @@ def data_gen_fn(): ...@@ -17,6 +17,13 @@ def data_gen_fn():
return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) return dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
def data_gen_for_pretrain():
    """Return the base inputs extended with pretraining targets.

    Starts from the mapping produced by ``data_gen_fn()`` and adds two entries:

    * ``labels`` -- a clone of ``input_ids``.
    * ``sentence_order_label`` -- an all-zero ``int64`` tensor of length
      ``BATCH_SIZE``.
    """
    batch = data_gen_fn()

    # Clone so the labels do not share storage with the input ids.
    label_ids = batch['input_ids'].clone()
    batch['labels'] = label_ids

    order_targets = torch.zeros(BATCH_SIZE, dtype=torch.int64)
    batch['sentence_order_label'] = order_targets

    return batch
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
config = transformers.AlbertConfig(embedding_size=128, config = transformers.AlbertConfig(embedding_size=128,
...@@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128, ...@@ -26,14 +33,14 @@ config = transformers.AlbertConfig(embedding_size=128,
intermediate_size=256) intermediate_size=256)
model_zoo.register(name='transformers_albert', model_zoo.register(name='transformers_albert',
model_fn=lambda: transformers.AlbertModel(config), model_fn=lambda: transformers.AlbertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen_fn, data_gen_fn=data_gen_fn,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
model_attribute=ModelAttribute(has_control_flow=True)) model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_pretraining', model_zoo.register(name='transformers_albert_for_pretraining',
model_fn=lambda: transformers.AlbertForPreTraining(config), model_fn=lambda: transformers.AlbertForPreTraining(config),
data_gen_fn=data_gen_fn, data_gen_fn=data_gen_for_pretrain,
output_transform_fn=output_transform_fn, output_transform_fn=lambda x: dict(loss=x.loss),
model_attribute=ModelAttribute(has_control_flow=True)) model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_albert_for_masked_lm', model_zoo.register(name='transformers_albert_for_masked_lm',
model_fn=lambda: transformers.AlbertForMaskedLM(config), model_fn=lambda: transformers.AlbertForMaskedLM(config),
......
...@@ -113,6 +113,7 @@ def data_gen_for_qa(): ...@@ -113,6 +113,7 @@ def data_gen_for_qa():
output_transform_fn = lambda x: x output_transform_fn = lambda x: x
# define loss function # define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state
)) ))
loss_fn = lambda x: x.loss loss_fn = lambda x: x.loss
...@@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128, ...@@ -126,7 +127,7 @@ config = transformers.BertConfig(hidden_size=128,
# register the BERT variants # register the BERT variants
model_zoo.register(name='transformers_bert', model_zoo.register(name='transformers_bert',
model_fn=lambda: transformers.BertModel(config), model_fn=lambda: transformers.BertModel(config, add_pooling_layer=False),
data_gen_fn=data_gen, data_gen_fn=data_gen,
output_transform_fn=output_transform_fn, output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_bert_model, loss_fn=loss_fn_for_bert_model,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment