Unverified Commit 24cbee0e authored by Boyuan Yao's avatar Boyuan Yao Committed by GitHub
Browse files

[tutorial] modify hands-on of auto activation checkpoint (#1920)

* [sc] SC tutorial for auto checkpoint

* [sc] polish examples

* [sc] polish readme

* [sc] polish readme and help information

* [sc] polish readme and help information

* [sc] modify auto checkpoint benchmark

* [sc] remove imgs
parent ff16773d
...@@ -19,79 +19,38 @@ colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py ...@@ -19,79 +19,38 @@ colossalai run --nproc_per_node 4 auto_parallel_with_resnet.py
## Auto Checkpoint Benchmarking ## Auto Checkpoint Benchmarking
We prepare three demos for you to test the performance of auto checkpoint, the test `demo_resnet50.py` and `demo_gpt2_medium.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget. We prepare two bechmarks for you to test the performance of auto checkpoint
The first test `auto_ckpt_solver_test.py` will show you the ability of solver to search checkpoint strategy that could fit in the given budget (test on GPT2 Medium and ResNet 50). It will output the benchmark summary and data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
The second test `auto_ckpt_batchsize_test.py` will show you the advantage of fitting larger batchsize training into limited GPU memory with the help of our activation checkpoint solver (test on ResNet152). It will output the benchmark summary.
The usage of the above two test The usage of the above two test
```bash ```bash
python demo_resnet50.py --help # run auto_ckpt_solver_test.py on gpt2 medium
usage: ResNet50 Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY] python auto_ckpt_solver_test.py --model gpt2
[--start_factor START_FACTOR]
# run auto_ckpt_solver_test.py on resnet50
optional arguments: python auto_ckpt_solver_test.py --model resnet50
-h, --help show this help message and exit
--batch_size BATCH_SIZE # tun auto_ckpt_batchsize_test.py
batch size for benchmark, default 128 python auto_ckpt_batchsize_test.py
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 11000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 4
# run with default settings
python demo_resnet50.py
python demo_gpt2_medium.py --help
usage: GPT2 medium Auto Activation Benchmark [-h] [--batch_size BATCH_SIZE] [--num_steps NUM_STEPS] [--sample_points SAMPLE_POINTS] [--free_memory FREE_MEMORY]
[--start_factor START_FACTOR]
optional arguments:
-h, --help show this help message and exit
--batch_size BATCH_SIZE
batch size for benchmark, default 8
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
--sample_points SAMPLE_POINTS
number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15
--free_memory FREE_MEMORY
maximum memory budget in MB for benchmark, default 56000 MB
--start_factor START_FACTOR
start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10
# run with default settings
python demo_gpt2_medium.py
``` ```
There are some results for your reference There are some results for your reference
## Auto Checkpoint Solver Test
### ResNet 50 ### ResNet 50
![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/resnet50_benchmark.png) ![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/resnet50_benchmark.png)
### GPT2 Medium ### GPT2 Medium
![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/gpt2_benchmark.png) ![](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/tutorial/gpt2_benchmark.png)
We also prepare the demo `demo_resnet152.py` to manifest the benefit of auto activation with large batch, the usage is listed as follows ## Auto Checkpoint Batch Size Test
```bash
python demo_resnet152.py --help
usage: ResNet152 Auto Activation Through Put Benchmark [-h] [--num_steps NUM_STEPS]
optional arguments:
-h, --help show this help message and exit
--num_steps NUM_STEPS
number of test steps for benchmark, default 5
# run with default settings
python demo_resnet152.py
```
here are some results on our end for your reference
```bash ```bash
===============test summary================ ===============test summary================
batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s batch_size: 512, peak memory: 73314.392 MB, through put: 254.286 images/s
batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s batch_size: 1024, peak memory: 73316.216 MB, through put: 397.608 images/s
batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s batch_size: 2048, peak memory: 72927.837 MB, through put: 277.429 images/s
``` ```
The above tests will output the test summary and a plot of the benchmarking results.
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torchvision.models as tm import torchvision.models as tm
from bench_utils import bench from bench_utils import bench, data_gen_resnet
import colossalai import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
...@@ -16,19 +16,14 @@ from colossalai.fx import metainfo_trace, symbolic_trace ...@@ -16,19 +16,14 @@ from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port from colossalai.utils import free_port
def data_gen(batch_size, shape, device='cuda'): def _benchmark(rank, world_size, port):
""" """Auto activation checkpoint batchsize benchmark
Generate random data for benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label
def _resnet152_benchmark(rank, world_size, port, num_steps):
"""Resnet152 benchmark
This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of This benchmark test the through put of Resnet152 with our activation solver given the memory budget of 95% of
maximum GPU memory, and with the batch size of [512, 1024, 2048] maximum GPU memory, and with the batch size of [512, 1024, 2048], you could see that using auto activation
checkpoint with optimality guarantee, we might be able to find better batch size for the model, as larger batch
size means that we are able to use larger portion of GPU FLOPS, while recomputation scheduling with our solver
only result in minor performance drop. So at last we might be able to find better training batch size for our
model (combine with large batch training optimizer such as LAMB).
""" """
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = tm.resnet152() model = tm.resnet152()
...@@ -42,33 +37,23 @@ def _resnet152_benchmark(rank, world_size, port, num_steps): ...@@ -42,33 +37,23 @@ def _resnet152_benchmark(rank, world_size, port, num_steps):
gm.graph = solver.solve() gm.graph = solver.solve()
peak_mem, step_time = bench(gm, peak_mem, step_time = bench(gm,
torch.nn.CrossEntropyLoss(), torch.nn.CrossEntropyLoss(),
partial(data_gen, batch_size=batch_size, shape=(3, 224, 224)), partial(data_gen_resnet, batch_size=batch_size, shape=(3, 224, 224)),
num_steps=num_steps) num_steps=5)
peak_mems.append(peak_mem) peak_mems.append(peak_mem)
through_puts.append(batch_size / step_time * 1.0e3) through_puts.append(batch_size / step_time * 1.0e3)
gm.graph = deepcopy(raw_graph) gm.graph = deepcopy(raw_graph)
# print results # print results
print("===============test summary================") print("===============benchmark summary================")
for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts): for batch_size, peak_mem, through_put in zip(batch_sizes, peak_mems, through_puts):
print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s') print(f'batch_size: {int(batch_size)}, peak memory: {peak_mem:.3f} MB, through put: {through_put:.3f} images/s')
plt.plot(batch_sizes, through_puts)
plt.xlabel("batch size")
plt.ylabel("through put (images/s)")
plt.title("Resnet152 benchmark")
plt.savefig("resnet152_benchmark.png")
def resnet152_benchmark(num_steps): def auto_activation_checkpoint_batchsize_benchmark():
world_size = 1 world_size = 1
run_func_module = partial(_resnet152_benchmark, world_size=world_size, port=free_port(), num_steps=num_steps) run_func_module = partial(_benchmark, world_size=world_size, port=free_port())
mp.spawn(run_func_module, nprocs=world_size) mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("ResNet152 Auto Activation Through Put Benchmark") auto_activation_checkpoint_batchsize_benchmark()
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
args = parser.parse_args()
resnet152_benchmark(args.num_steps)
...@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt ...@@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torchvision.models as tm import torchvision.models as tm
from bench_utils import GPTLMLoss, bench_rotor, gpt2_medium from bench_utils import GPTLMLoss, bench_rotor, data_gen_gpt2, data_gen_resnet, gpt2_medium
import colossalai import colossalai
from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor from colossalai.auto_parallel.checkpoint import CheckpointSolverRotor
...@@ -14,34 +14,41 @@ from colossalai.fx import metainfo_trace, symbolic_trace ...@@ -14,34 +14,41 @@ from colossalai.fx import metainfo_trace, symbolic_trace
from colossalai.utils import free_port from colossalai.utils import free_port
def data_gen(batch_size, seq_len, vocab_size, device='cuda:0'): def _benchmark(rank, world_size, port, args):
""" """
Generate random data for benchmarking Auto activation checkpoint solver benchmark, we provide benchmark on two models: gpt2_medium and resnet50.
The benchmark will sample in a range of memory budget for each model and output the benchmark summary and
data visualization of peak memory vs. budget memory and relative step time vs. peak memory.
""" """
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
return (input_ids, attention_mask), attention_mask
def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points, free_memory, start_factor):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = gpt2_medium() if args.model == 'resnet50':
model = tm.resnet50()
data_gen = partial(data_gen_resnet, batch_size=128, shape=(3, 224, 224))
gm = symbolic_trace(model)
gm = metainfo_trace(gm, torch.empty(128, 3, 224, 224, device='meta'))
loss = torch.nn.CrossEntropyLoss()
else:
model = gpt2_medium()
data_gen = partial(data_gen_gpt2, batch_size=8, seq_len=1024, vocab_size=50257)
data, mask = data_gen(device='meta')[0]
gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
gm = metainfo_trace(gm, data, mask)
loss = GPTLMLoss()
free_memory = 11000 * 1024**2 if args.model == 'resnet50' else 56000 * 1024**2
start_factor = 4 if args.model == 'resnet50' else 10
# trace and benchmark # trace and benchmark
data, mask = data_gen(batch_size, 1024, 50257, device='meta')[0]
gm = symbolic_trace(model, meta_args={'input_ids': data, 'attention_mask': mask})
gm = metainfo_trace(gm, data, mask)
budgets, peak_hist, step_hist = bench_rotor(gm, budgets, peak_hist, step_hist = bench_rotor(gm,
GPTLMLoss(), loss,
partial(data_gen, batch_size=batch_size, seq_len=1024, data_gen,
vocab_size=50257), num_steps=5,
num_steps=num_steps, sample_points=15,
sample_points=sample_points,
free_memory=free_memory, free_memory=free_memory,
start_factor=start_factor) start_factor=start_factor)
# print summary # print summary
print("==============test summary==============") print("==============benchmark summary==============")
for budget, peak, step in zip(budgets, peak_hist, step_hist): for budget, peak, step in zip(budgets, peak_hist, step_hist):
print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS') print(f'memory budget: {budget:.3f} MB, peak memory: {peak:.3f} MB, step time: {step:.3f} MS')
...@@ -65,44 +72,18 @@ def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points ...@@ -65,44 +72,18 @@ def _gpt2_benchmark(rank, world_size, port, batch_size, num_steps, sample_points
axs[1].set_ylim(0.8, 1.5) axs[1].set_ylim(0.8, 1.5)
# save plot # save plot
fig.savefig("gpt2_benchmark.png") fig.savefig(f"{args.model}_benchmark.png")
def gpt2_benchmark(batch_size, num_steps, sample_points, free_memory, start_factor): def auto_activation_checkpoint_benchmark(args):
world_size = 1 world_size = 1
run_func_module = partial(_gpt2_benchmark, run_func_module = partial(_benchmark, world_size=world_size, port=free_port(), args=args)
world_size=world_size,
port=free_port(),
batch_size=batch_size,
num_steps=num_steps,
sample_points=sample_points,
free_memory=free_memory,
start_factor=start_factor)
mp.spawn(run_func_module, nprocs=world_size) mp.spawn(run_func_module, nprocs=world_size)
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("GPT2 medium Auto Activation Benchmark") parser = ArgumentParser("Auto Activation Checkpoint Solver Benchmark")
parser.add_argument("--batch_size", type=int, default=8, help="batch size for benchmark, default 8") parser.add_argument("--model", type=str, default='gpt2', choices=['gpt2', 'resnet50'])
parser.add_argument("--num_steps", type=int, default=5, help="number of test steps for benchmark, default 5")
parser.add_argument(
"--sample_points",
type=int,
default=15,
help=
"number of sample points for benchmark from start memory budget to maximum memory budget (free_memory), default 15"
)
parser.add_argument("--free_memory",
type=int,
default=56000,
help="maximum memory budget in MB for benchmark, default 56000 MB")
parser.add_argument(
"--start_factor",
type=int,
default=10,
help=
"start memory budget factor for benchmark, the start memory budget will be free_memory / start_factor, default 10"
)
args = parser.parse_args() args = parser.parse_args()
gpt2_benchmark(args.batch_size, args.num_steps, args.sample_points, args.free_memory * 1024**2, args.start_factor) auto_activation_checkpoint_benchmark(args)
...@@ -154,3 +154,21 @@ def gpt2_xl(checkpoint=False): ...@@ -154,3 +154,21 @@ def gpt2_xl(checkpoint=False):
def gpt2_6b(checkpoint=False): def gpt2_6b(checkpoint=False):
return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint) return GPTLMModel(hidden_size=4096, num_layers=30, num_attention_heads=16, checkpoint=checkpoint)
def data_gen_gpt2(batch_size, seq_len, vocab_size, device='cuda:0'):
"""
Generate random data for gpt2 benchmarking
"""
input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
attention_mask = torch.ones_like(input_ids, device=device)
return (input_ids, attention_mask), attention_mask
def data_gen_resnet(batch_size, shape, device='cuda:0'):
"""
Generate random data for resnet benchmarking
"""
data = torch.empty(batch_size, *shape, device=device)
label = torch.empty(batch_size, dtype=torch.long, device=device).random_(1000)
return (data,), label
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment