Unverified commit a7d95b70, authored by HELSON and committed by GitHub

[example] add zero1, zero2 example in GPT examples (#2146)

* [example] add zero1 and zero2 for GPT

* update readme in gpt example

* polish code

* change init value

* update readme
parent 1cce6e36
@@ -35,13 +35,13 @@ class LowLevelZeroOptimizer(ColossalaiOptimizer):
             optimizer: Optimizer,
 
             # grad scaler config
-            initial_scale=2**32,
+            initial_scale=2**16,
             min_scale=1,
             growth_factor=2,
             backoff_factor=0.5,
-            growth_interval=1000,
+            growth_interval=2000,
             hysteresis=2,
-            max_scale: int = 2**32,
+            max_scale: int = 2**24,
 
             # grad clipping
             clip_grad_norm=0.0,
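For context on the grad-scaler defaults changed above: in dynamic loss scaling, the loss is multiplied by a scale before backward; on repeated overflow the scale is backed off, and after `growth_interval` clean steps it grows again, clamped between `min_scale` and `max_scale`. The sketch below is only an illustration of what these knobs mean, assuming the usual dynamic-loss-scaling semantics; `ToyDynamicScaler` is an invented name, not the ColossalAI grad scaler, and the library's exact hysteresis behaviour may differ.

```python
class ToyDynamicScaler:
    """Minimal illustration of dynamic loss scaling (not the library code)."""

    def __init__(self, initial_scale=2**16, min_scale=1, growth_factor=2,
                 backoff_factor=0.5, growth_interval=2000, hysteresis=2,
                 max_scale=2**24):
        self.scale = float(initial_scale)
        self.min_scale = float(min_scale)
        self.max_scale = float(max_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis      # consecutive overflows tolerated before backing off
        self._good_steps = 0
        self._bad_steps = 0

    def update(self, found_overflow: bool) -> None:
        if found_overflow:
            self._good_steps = 0
            self._bad_steps += 1
            if self._bad_steps >= self.hysteresis:
                # enough consecutive overflows: back off, but never below min_scale
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self._bad_steps = 0
        else:
            self._bad_steps = 0
            self._good_steps += 1
            if self._good_steps % self.growth_interval == 0:
                # a long clean stretch: grow the scale, capped at max_scale
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
```

With the new defaults in this toy, the scale starts at 2**16, is halved only after two consecutive overflowing steps, doubles after every 2000 consecutive clean steps, and never exceeds 2**24.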
@@ -19,10 +19,10 @@ conda install pytorch==1.12.0 torchvision==0.13.0 torchaudio==0.12.0 cudatoolkit
 pip install torch==1.12.0+cu113 torchvision==0.13.0+cu113 torchaudio==0.12.0 --extra-index-url https://download.pytorch.org/whl/cu113
 ```
 
-### Install [Colossal-AI v0.1.11rc5](https://colossalai.org/download/) From Official Website
+### Install [Colossal-AI v0.1.12](https://colossalai.org/download/) From Official Website
 
 ```bash
-pip install colossalai==0.1.11rc5+torch1.12cu11.3 -f https://release.colossalai.org
+pip install colossalai==0.1.12+torch1.12cu11.3 -f https://release.colossalai.org
 ```
 
 ### Install transformers
@@ -31,7 +31,8 @@ pip install colossalai==0.1.11rc5+torch1.12cu11.3 -f https://release.colossalai.
 pip install transformers
 ```
 
-This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.11rc5+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
+This is just an example that we download PyTorch=1.12.0, CUDA=11.6 and colossalai=0.1.12+torch1.12cu11.3. You can download another version of PyTorch and its corresponding ColossalAI version. Just make sure that the version of ColossalAI is at least 0.1.10, PyTorch is at least 1.8.1 and transformers is at least 4.231.
+If you want to test ZeRO1 and ZeRO2 in Colossal-AI, you need to ensure Colossal-AI>=0.1.12.
 
 ## Dataset
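If it helps to confirm that an installed environment meets the minimums stated above, a quick check (assuming the packages are importable in the active environment) is:

```bash
python -c "import torch; print('torch', torch.__version__)"
python -c "import transformers; print('transformers', transformers.__version__)"
python -c "import colossalai; print('colossalai', colossalai.__version__)"
```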
@@ -48,5 +49,7 @@ bash run.sh
 
 The `train_gpt_demo.py` provides three distributed plans, you can choose the plan you want in `run.sh`. The Colossal-AI leverages Tensor Parallel and Gemini + ZeRO DDP.
 - Colossal-AI
-- PyTorch DDP
-- ZeRO
\ No newline at end of file
+- ZeRO1 (Colossal-AI)
+- ZeRO2 (Colossal-AI)
+- Pytorch DDP
+- Pytorch ZeRO

-colossalai >= 0.1.10
+colossalai >= 0.1.12
 torch >= 1.8.1
 transformers >= 4.231

-# distplan in ["colossalai", "zero", "ddp"]
+# distplan in ["colossalai", "zero1", "zero2", "torch_ddp", "torch_zero"]
 export DISTPAN="colossalai"
 
 # The following options only valid when DISTPAN="colossalai"
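Concretely, selecting ZeRO2 means editing the `DISTPAN` line inside `run.sh` (the variable is spelled `DISTPAN` in the script) and launching as before; `zero2` here is just one example value, and the other options in `run.sh` are left untouched:

```bash
# in run.sh, pick one of: colossalai, zero1, zero2, torch_ddp, torch_zero
export DISTPAN="zero2"
```

```bash
# then, from the example directory
bash run.sh
```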
@@ -6,6 +6,7 @@ import torch
 import torch.nn as nn
 from packaging import version
 from torch.nn.parallel import DistributedDataParallel as DDP
+from transformers import GPT2Config, GPT2LMHeadModel
 
 import colossalai
 from colossalai.logging import disable_existing_loggers, get_dist_logger
@@ -16,7 +17,7 @@ from colossalai.nn.parallel import ZeroDDP
 from colossalai.tensor import ColoParameter, ComputePattern, ComputeSpec, ProcessGroup, ReplicaSpec, ShardSpec
 from colossalai.utils import get_current_device
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from transformers import GPT2Config, GPT2LMHeadModel
+from colossalai.zero.sharded_optim import LowLevelZeroOptimizer
 
 
 def parse_args():
@@ -25,7 +26,7 @@ def parse_args():
         "--distplan",
         type=str,
         default='colossalai',
-        help="The distributed plan [colossalai, ddp, zero].",
+        help="The distributed plan [colossalai, zero1, zero2, torch_ddp, torch_zero].",
     )
     parser.add_argument(
         "--tp_degree",
@@ -202,6 +203,9 @@ def gemini_zero_dpp(model: torch.nn.Module, pg: ProcessGroup, placememt_policy:
 def main():
     args = parse_args()
 
+    if args.distplan not in ["colossalai", "torch_ddp", "torch_zero", "zero1", "zero2"]:
+        raise TypeError(f"{args.distplan} is error")
+
     BATCH_SIZE = 8
     SEQ_LEN = 1024
     VOCAB_SIZE = 50257
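For reference, the plan accepted by this check is the same string passed via `--distplan`. As an illustration only, assuming the script is launched through torchrun-style environment initialization (the exact launcher and extra flags used by the repo's `run.sh` may differ):

```bash
torchrun --standalone --nproc_per_node=2 train_gpt_demo.py --distplan zero2
```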
@@ -237,19 +241,24 @@ def main():
         # optimizer = HybridAdam(model.parameters(), lr=1e-3)
         # optimizer = ZeroOptimizer(optimizer, model, initial_scale=2**5)
         logger.info(get_mem_info(prefix='After init optim, '), ranks=[0])
 
-    elif args.distplan == "ddp":
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=0.01)
-
-    elif args.distplan == "zero":
-        from torch.distributed.optim import ZeroRedundancyOptimizer
-        model = gpt2_medium(checkpoint=True).cuda()
-        ddp_model = DDP(model)
-        optimizer = ZeroRedundancyOptimizer(ddp_model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
-    else:
-        raise TypeError(f"{args.distplan} is error")
+    else:
+        model = gpt2_medium(checkpoint=True).cuda()
+        if args.distplan.startswith("torch"):
+            model = DDP(model)
+            if args.distplan.endswith("ddp"):
+                optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            elif args.distplan.endswith("zero"):
+                from torch.distributed.optim import ZeroRedundancyOptimizer
+                optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.01)
+        elif args.distplan.startswith("zero"):
+            partition_flag = args.distplan == "zero2"
+            optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+            optimizer = LowLevelZeroOptimizer(optimizer,
+                                              overlap_communication=True,
+                                              partition_grad=partition_flag,
+                                              verbose=True)
+    # notice that the model is still in fp32
 
     numel = sum([p.numel() for p in model.parameters()])
     logger.info(get_mem_info(prefix='After init model, '), ranks=[0])
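Pulled out of the hunk above for readability: the new branch simply wraps an ordinary `torch.optim.Adam` in `LowLevelZeroOptimizer`, and `zero1` vs `zero2` differ only in whether gradients are also partitioned across ranks (`partition_grad`). The helper below is a hypothetical convenience, not part of the example script; the constructor call mirrors the diff.

```python
import torch
from colossalai.zero.sharded_optim import LowLevelZeroOptimizer


def build_zero_optimizer(model: torch.nn.Module, distplan: str) -> LowLevelZeroOptimizer:
    # base optimizer is a plain torch.optim.Adam, exactly as in the example script
    base_optim = torch.optim.Adam(model.parameters(), lr=0.01)
    return LowLevelZeroOptimizer(
        base_optim,
        overlap_communication=True,              # overlap gradient communication with backward
        partition_grad=(distplan == "zero2"),    # True -> ZeRO stage 2 (gradients partitioned too)
        verbose=True)
```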
@@ -265,12 +274,13 @@ def main():
         outputs = model(input_ids, attn_mask)
         loss = criterion(outputs, input_ids)
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Forward '), ranks=[0])
-        if args.distplan == "colossalai":
+        if args.distplan in ["colossalai", "zero1", "zero2"]:
             optimizer.backward(loss)
-        elif args.distplan in ["ddp", "zero"]:
+        elif args.distplan in ["torch_ddp", "torch_zero"]:
             loss.backward()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Backward '), ranks=[0])
-
+        if args.distplan in ["zero1", "zero2"]:
+            optimizer.sync_grad()
         optimizer.step()
         logger.info(get_mem_info(prefix=f'[{n+1}/{NUM_STEPS}] Optimizer step '), ranks=[0])
         step_time = time() - start
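Putting the training-loop change together: the ColossalAI and ZeRO plans back-propagate through the wrapped optimizer, the zero1/zero2 plans additionally synchronize partitioned gradients before the update, and the torch plans keep the ordinary `loss.backward()`. The function below is a hypothetical condensation of one step, using the same calls as the hunk above; all arguments are placeholders supplied by the caller.

```python
def train_step(model, optimizer, criterion, input_ids, attn_mask, distplan: str):
    outputs = model(input_ids, attn_mask)
    loss = criterion(outputs, input_ids)

    if distplan in ["colossalai", "zero1", "zero2"]:
        optimizer.backward(loss)      # wrapped optimizer handles scaling / backward hooks
    elif distplan in ["torch_ddp", "torch_zero"]:
        loss.backward()

    if distplan in ["zero1", "zero2"]:
        optimizer.sync_grad()         # reduce partitioned gradients before the update

    optimizer.step()
    return loss
```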