Commit 890774b2 authored by Hongxin Liu

[shardformer] support lazy init (#4202)

* [shardformer] support lazy init

* [shardformer] linear support lazy init

* [shardformer] embedding support lazy init

* [shardformer] norm support lazy init

* [shardformer] fused linear support lazy init

* [test] update shardformer test layer

* [test] shardformer with lazy init fit ddp

* [lazy] hotfix deepcopy of param

* [shardformer] fix bert policy and update test

* [shardformer] fix bloom policy and update test

* [shardformer] fix opt policy and update test

* [shardformer] fix t5 policy and update test

* [shardformer] fix gpt2 policy and update test

* [shardformer] fix llama policy and update test
parent f3bcc292
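
The diff below only touches the shardformer tests, but it shows the pattern the feature enables: build the model under `LazyInitContext` so parameter allocation is deferred, then let `ShardFormer.optimize` produce the sharded module. A minimal sketch of that flow, assuming only the `LazyInitContext`, `ShardConfig`, and `ShardFormer` APIs that appear in the diff (the helper name is hypothetical, and `colossalai.launch` is assumed to have initialized the distributed environment already):

```python
from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def shard_with_optional_lazy_init(model_fn, lazy_init: bool = True):
    # Hypothetical helper; the pattern mirrors check_shardformer_with_ddp below.
    ctx = LazyInitContext() if lazy_init else nullcontext()
    with ctx:
        # Under LazyInitContext, parameter construction is deferred, so a large
        # model can be described without allocating its full weights up front.
        model = model_fn().cuda()
    shard_config = ShardConfig(enable_fused_normalization=True)
    sharded_model, shared_params = ShardFormer(shard_config=shard_config).optimize(model)
    return sharded_model
```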
@@ -69,10 +69,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
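
The same two-line change repeats for the llama, opt, and t5 tests below: `use_lazy_init` is simply threaded through to the shared `build_model` helper. Its body is not part of this diff; the following is a hypothetical sketch of what such a helper might do, assuming a `LazyInitContext.materialize` call and `ShardConfig` fields that mirror the test parameter names:

```python
import copy
from contextlib import nullcontext

from colossalai.lazy import LazyInitContext
from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism, use_lazy_init=False):
    # Optionally build under LazyInitContext; the "[lazy] hotfix deepcopy of param"
    # commit above is what makes deepcopy of a lazily initialized model safe.
    ctx = LazyInitContext() if use_lazy_init else nullcontext()
    with ctx:
        org_model = model_fn()
        sharded_model = copy.deepcopy(org_model)
    if use_lazy_init:
        # Assumed API: materialize the reference copy so its forward/backward
        # results can be compared against the sharded copy.
        ctx.materialize(org_model)
    shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
                               enable_tensor_parallelism=enable_tensor_parallelism)
    sharded_model, _ = ShardFormer(shard_config=shard_config).optimize(sharded_model)
    return org_model.cuda(), sharded_model.cuda()
```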
@@ -72,10 +72,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -71,10 +71,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
@@ -82,10 +82,12 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
 @parameterize('enable_fused_normalization', [True, False])
 @parameterize('enable_tensor_parallelism', [True, False])
-def run_t5_test(enable_fused_normalization, enable_tensor_parallelism):
+@parameterize('use_lazy_init', [False, True])
+def run_t5_test(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism)
+        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
+                                               use_lazy_init)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
     torch.cuda.empty_cache()
+from contextlib import nullcontext
+
 import pytest
 import torch
 import torch.distributed as dist
@@ -5,15 +7,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 
 import colossalai
 from colossalai.cluster import DistCoordinator
+from colossalai.lazy import LazyInitContext
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
-from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
-def check_shardformer_with_ddp(rank, world_size, port):
-    disable_existing_loggers()
-    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+@parameterize('lazy_init', [True, False])
+def check_shardformer_with_ddp(lazy_init: bool):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
@@ -41,9 +43,12 @@ def check_shardformer_with_ddp(rank, world_size, port):
     shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
     shardformer = ShardFormer(shard_config=shard_config)
 
+    ctx = LazyInitContext() if lazy_init else nullcontext()
+
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         # create and shard model
-        model = model_fn().cuda()
+        with ctx:
+            model = model_fn().cuda()
         sharded_model, _ = shardformer.optimize(model)
 
         # add ddp
@@ -65,13 +70,18 @@ def check_shardformer_with_ddp(rank, world_size, port):
     torch.cuda.empty_cache()
 
 
+def run_dist(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    check_shardformer_with_ddp()
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 @clear_cache_before_run()
 def test_gpt2():
-    spawn(check_shardformer_with_ddp, 4)
+    spawn(run_dist, 4)
 
 
 if __name__ == "__main__":
     test_gpt2()
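
The split into `run_dist` is needed because `@parameterize` turns `check_shardformer_with_ddp` into a zero-argument loop over `lazy_init`, so the spawned worker entry point must call `colossalai.launch` itself. The `# add ddp` section is cut off in the hunk above; presumably the next step wraps the already-sharded module in `DistributedDataParallel` over the data-parallel group. A hypothetical sketch of that step (`dp_process_group` is assumed to be created earlier in the test, alongside `tp_process_group`):

```python
from torch.nn.parallel import DistributedDataParallel as DDP

# Sharding runs first (tensor parallelism within the TP group); DDP then replicates
# the sharded module across the data-parallel group, so any lazy initialization must
# be fully resolved before gradients are bucketed.
ddp_model = DDP(sharded_model, process_group=dp_process_group)
```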