Unverified commit a25f7553 authored by Jiarui Fang, committed by GitHub

[example] add TP to GPT example (#1828)

parent 49216d7a
# Gemini + ZeRO DDP only, 2 GPUs:
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=2 train_gpt_demo.py 2>&1 | tee run.log
# Gemini + ZeRO DDP with 2-way tensor parallelism on 4 GPUs, CPU placement policy:
env OMP_NUM_THREADS=16 torchrun --standalone --nproc_per_node=4 train_gpt_demo.py --tp_degree=2 --placement='cpu' 2>&1 | tee run.log
@@ -10,13 +10,48 @@ import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import GPT2Config, GPT2LMHeadModel
def parse_args():
    parser = colossalai.get_default_parser()
    parser.add_argument(
        "--tp_degree",
        type=int,
        default=1,
        help="Tensor Parallelism Degree.",
    )
    parser.add_argument(
        "--placement",
        type=str,
        default='cpu',
        help="Placement Policy for Gemini.",
    )
    args = parser.parse_args()
    return args
## Parameter Sharding Strategies for Tensor Parallelism
def split_param_single_dim_tp1d(dim: int, param: ColoParameter, pg: ProcessGroup):
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    if param.process_group.tp_world_size() == 1:
        param.set_process_group(pg)
    param.set_tensor_spec(*spec)


def split_param_row_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param: ColoParameter, pg: ProcessGroup):
    split_param_single_dim_tp1d(-1, param, pg)
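To see concretely what the two 1D strategies do, here is a standalone sketch in plain PyTorch (no Colossal-AI types; the 2-rank TP world size is assumed for illustration):

import torch

# A toy 4x6 "weight" sharded across an assumed TP group of 2 ranks.
weight = torch.arange(24.).reshape(4, 6)
world_size = 2

# Row shard (dim 0): each rank would hold a [2, 6] slice, like split_param_row_tp1d.
row_shards = torch.chunk(weight, world_size, dim=0)

# Column shard (dim -1): each rank would hold a [4, 3] slice, like split_param_col_tp1d.
col_shards = torch.chunk(weight, world_size, dim=-1)

print(row_shards[0].shape, col_shards[0].shape)    # torch.Size([2, 6]) torch.Size([4, 3])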
## Define the Model and Loss Based on Huggingface transformers GPT2LMHeadModel
class GPTLMModel(nn.Module):

    def __init__(self,
@@ -56,6 +91,7 @@ class GPTLMLoss(nn.Module):
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
## Randomly Generated Data
def get_data(batch_size, seq_len, vocab_size):
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=torch.cuda.current_device())
    attention_mask = torch.ones_like(input_ids)
@@ -90,54 +126,96 @@ def get_tflops(model_numel, batch_size, seq_len, step_time):
    return model_numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
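As an aside, the constant 8 in get_tflops is the standard estimate with activation checkpointing: 2 FLOPs per parameter per token forward, twice that backward, plus one recomputed forward, i.e. 2 x (1 + 2 + 1) = 8. A quick sanity check with gpt2-medium-scale numbers (~355M parameters, assumed):

numel, batch_size, seq_len, step_time = 355e6, 8, 1024, 1.0
tflops = numel * batch_size * seq_len * 8 / 1e12 / (step_time + 1e-12)
print(f'{tflops:.1f} TFLOPS at a 1 s step')    # ~23.3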
# Tensor Parallel
def tensor_parallelize(model: torch.nn.Module, pg: ProcessGroup):
    """tensor_parallelize
    Sharding the Model Parameters.

    Args:
        model (torch.nn.Module): a torch module to be sharded
        pg (ProcessGroup): the process group the shards live in
    """
    for mn, module in model.named_modules():
        for pn, param in module.named_parameters(recurse=False):
            # set process group for all parameters
            param.set_process_group(pg)

            if 'mlp.c_fc' in mn:
                if 'weight' in pn or 'bias' in pn:
                    split_param_col_tp1d(param, pg)    # column slice
                    # keep the shape of the output from c_fc
                    param.compute_spec.set_output_replicate(False)
            elif 'mlp.c_proj' in mn:
                if 'weight' in pn:
                    split_param_row_tp1d(param, pg)    # row slice
            elif 'wte' in mn or 'wpe' in mn:
                split_param_col_tp1d(param, pg)    # column slice
            elif 'c_attn' in mn or 'c_proj' in mn:
                split_param_col_tp1d(param, pg)    # column slice
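For reference, the branch each GPT-2 parameter takes can be checked offline with a tiny dry run (standard Huggingface module names; the one-layer config is only for inspecting names):

from transformers import GPT2Config, GPT2LMHeadModel

tiny = GPT2LMHeadModel(GPT2Config(n_layer=1))
for mn, module in tiny.named_modules():
    for pn, _ in module.named_parameters(recurse=False):
        if 'mlp.c_fc' in mn and ('weight' in pn or 'bias' in pn):
            plan = 'column (output kept sharded)'
        elif 'mlp.c_proj' in mn:
            plan = 'row' if 'weight' in pn else 'replicated'
        elif 'wte' in mn or 'wpe' in mn or 'c_attn' in mn or 'c_proj' in mn:
            plan = 'column'
        else:
            plan = 'replicated'
        print(f'{mn}.{pn}: {plan}')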
# Gemini + ZeRO DDP
def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placement_policy: str = "auto"):
    cai_version = colossalai.__version__
    if version.parse(cai_version) > version.parse("0.1.10"):
        from colossalai.nn.parallel import GeminiDDP
        model = GeminiDDP(model,
                          device=get_current_device(),
                          placement_policy=placement_policy,
                          pin_memory=True,
                          search_range_mb=32)
    elif version.parse(cai_version) <= version.parse("0.1.10") and version.parse(cai_version) >= version.parse("0.1.9"):
        from colossalai.gemini import ChunkManager, GeminiManager
        chunk_size = ChunkManager.search_chunk_size(model, 64 * 1024**2, 32)
        # the chunk manager must be built before the Gemini manager that consumes it
        chunk_manager = ChunkManager(chunk_size,
                                     pg,
                                     enable_distributed_storage=True,
                                     init_device=GeminiManager.get_default_device(placement_policy))
        gemini_manager = GeminiManager(placement_policy, chunk_manager)
        model = ZeroDDP(model, gemini_manager)
    else:
        raise NotImplementedError(f"CAI version {cai_version} is not supported")
    return model
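The helper assumes the TP sharding above has already been applied; a minimal call-order sketch (2-way TP assumed, mirroring main() below):

# Order matters: shard parameters with TP first, then wrap with Gemini/ZeRO.
pg = ProcessGroup(tp_degree=2)
tensor_parallelize(model, pg)
model = gemini_zero_dpp(model, pg, placement_policy='cpu')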
def main():
    args = parse_args()

    BATCH_SIZE = 8
    SEQ_LEN = 1024
    VOCAB_SIZE = 50257
    NUM_STEPS = 10

    disable_existing_loggers()
    colossalai.launch_from_torch(config={})

    pg = ProcessGroup(tp_degree=args.tp_degree)

    logger = get_dist_logger()
    logger.info(get_mem_info(), ranks=[0])

    # build GPT model
    with ColoInitContext(device=get_current_device()):
        model = gpt2_medium(checkpoint=True)
    numel = sum([p.numel() for p in model.parameters()])
    logger.info(f'Model numel: {numel}', ranks=[0])
    get_tflops_func = partial(get_tflops, numel, BATCH_SIZE, SEQ_LEN)
    # Tensor Parallelism (TP)
    tensor_parallelize(model, pg)

    # Gemini + ZeRO DP; note it must be applied after TP
    model = gemini_zero_dpp(model, pg, args.placement)
    logger.info(get_mem_info(prefix='After init model, '), ranks=[0])

    # build criterion
    criterion = GPTLMLoss()

    # build optimizer
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
    logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])

    torch.cuda.synchronize()
    model.train()
    for n in range(NUM_STEPS):
        # we just use randomly generated data here
@@ -156,6 +234,8 @@ def main():
            f'[{n+1}/{NUM_STEPS}] Loss:{loss.item():.3f}, Step time: {step_time:.3f}s, TFLOPS: {get_tflops_func(step_time):.3f}',
            ranks=[0])

    torch.cuda.synchronize()
if __name__ == '__main__':
    main()
@@ -36,7 +36,6 @@ from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

import colossalai
import transformers
@@ -47,7 +46,6 @@ from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
@@ -249,12 +247,20 @@ def parse_args():
    return args
def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))
def main():
    args = parse_args()
    disable_existing_loggers()
    colossalai.launch_from_torch(config=dict())
    logger = get_dist_logger()
    is_main_process = dist.get_rank() == 0

    if is_main_process:
        datasets.utils.logging.set_verbosity_warning()
...
import torch
import torch.distributed as dist


def memory_cap(size_in_GB):
    print(f"use only {size_in_GB} GB of CUDA memory")
    assert dist.is_initialized(), "memory_cap must be used after dist init"
    local_rank = dist.get_rank()
    cuda_capacity = torch.cuda.get_device_properties(local_rank).total_memory
    size_in_B = size_in_GB * 1024**3
    if size_in_B > cuda_capacity:
        print(f'memory_cap is useless since the GPU capacity {cuda_capacity / 1024**3} GB is less than {size_in_GB} GB')
        return
    fraction = size_in_B / cuda_capacity
    print(f'mem fraction is {fraction}')
    torch.cuda.set_per_process_memory_fraction(fraction, local_rank)
def colo_memory_cap(size_in_GB):
    from colossalai.utils import colo_device_memory_capacity, colo_set_process_memory_fraction, get_current_device
    cuda_capacity = colo_device_memory_capacity(get_current_device())
    if size_in_GB * (1024**3) < cuda_capacity:
        colo_set_process_memory_fraction(size_in_GB * (1024**3) / cuda_capacity)
        print("Using {} GB of GPU memory".format(size_in_GB))
if __name__ == '__main__':
    # memory_cap asserts an initialized process group, so set one up first;
    # launch this file with torchrun so the env:// rendezvous variables exist.
    dist.init_process_group(backend='nccl')
    memory_cap(40)