Unverified Commit b8e770c8 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
parent 3a41e830
...@@ -9,13 +9,12 @@ from torch.testing import assert_close ...@@ -9,13 +9,12 @@ from torch.testing import assert_close
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
...@@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict): ...@@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"]) @parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init]) @parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable): def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
set_seed(19360226) set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128) amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
...@@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal ...@@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
train_dataloader = iter(train_dataloader) train_dataloader = iter(DummyDataloader(data_gen_fn))
def train_iter(): def train_iter():
input_ids, label = next(train_dataloader) data = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim)
assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5) assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
check_param(model, torch_model) check_param(model, torch_model)
def inference_iter(): def inference_iter():
input_ids, label = next(train_dataloader) data = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
with torch.no_grad(): with torch.no_grad():
torch_output = torch_model(input_ids) torch_loss = run_fwd(torch_model, data, output_transform_fn)
torch_loss = criterion(torch_output.float(), label) zero_loss = run_fwd(model, data, output_transform_fn)
zero_output = model(input_ids) assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5)
zero_loss = criterion(zero_output.float(), label)
assert_close(torch_loss, zero_loss)
train_iter() train_iter()
inference_iter() inference_iter()
......
import pytest import pytest
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close from torch.testing import assert_close
import colossalai import colossalai
from colossalai.legacy.amp import convert_to_apex_amp from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
...@@ -32,14 +30,17 @@ PLACEMENT_CONFIGS = [ ...@@ -32,14 +30,17 @@ PLACEMENT_CONFIGS = [
] ]
# this model is large enough to slice to chunks # this model is large enough to slice to chunks
TEST_MODELS = ["gpt2"] TEST_MODELS = ["transformers_gpt_lm"]
# these models are too small, all parameters in these models are compacted into one chunk # these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"] EXAMPLE_MODELS = [
"transformers_bert_for_sequence_classification",
"custom_hanging_param_model",
"custom_nested_model",
"custom_repeated_computed_layers",
]
# bfloat16 cannot represent them exactly # bfloat16 cannot represent them exactly
BF16_IGNORED_KEYS = [ BF16_IGNORED_KEYS = [
"albert.embeddings.word_embeddings.weight",
"albert.embeddings.position_embeddings.weight",
"masked_bias", "masked_bias",
] ]
...@@ -55,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty ...@@ -55,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
temp_zero_value = zero_dict[key].to(device=value.device) temp_zero_value = zero_dict[key].to(device=value.device)
if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS): if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
continue continue
rtol, atol = 1e-3, 4e-3 rtol, atol = 2e-3, 6e-3
if dtype is torch.bfloat16: if dtype is torch.bfloat16:
rtol, atol = 4e-3, 8e-3 rtol, atol = 4e-3, 8e-3
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value))) # debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
...@@ -74,8 +75,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty ...@@ -74,8 +75,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("master_weights", [True, False]) @parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool): def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
# apex no master weights leads to nan, so we don't use it # apex no master weights leads to nan, so we don't use it
...@@ -104,19 +106,20 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt ...@@ -104,19 +106,20 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1e-4, 1e-5 rtol, atol = 4e-2, 4e-2
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = iter(DummyDataloader(data_gen_fn))
for i, data in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
input_ids, label = input_ids.cuda(), label.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
# as no master weights leads to error accumulation, we don't check the loss # as no master weights leads to error accumulation, we don't check the loss
if master_weights: if master_weights:
assert_close(torch_loss, loss, rtol=rtol, atol=atol) assert_close(torch_loss.float(), loss.float(), rtol=rtol, atol=atol)
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
...@@ -125,13 +128,14 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt ...@@ -125,13 +128,14 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
check_param(model, torch_model, mixed_precision) check_param(model, torch_model, mixed_precision)
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", [PLACEMENT_CONFIGS[3]])
@parameterize("model_name", EXAMPLE_MODELS) @parameterize("model_name", EXAMPLE_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16]) @parameterize("mixed_precision", [torch.half])
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype): def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(2008) set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda() torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2) amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2)
...@@ -159,26 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch. ...@@ -159,26 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
torch_model.eval() torch_model.eval()
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1.5e-6, 2e-5
if mixed_precision is torch.bfloat16:
rtol, atol = 2e-3, 2e-3
elif Version(torch.__version__) >= Version("2.0.0"):
rtol, atol = 4e-5, 3e-5
for i, (input_ids, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2: if i > 2:
break break
input_ids = input_ids.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
zero_optim.zero_grad() zero_optim.zero_grad()
torch_optim.zero_grad() torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim) run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim) run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12
zero_optim.step() zero_optim.step()
torch_optim.step() torch_optim.step()
......
...@@ -4,10 +4,9 @@ import numpy as np ...@@ -4,10 +4,9 @@ import numpy as np
import pytest import pytest
import torch import torch
from colossalai.testing import clear_cache_before_run from colossalai.testing import DummyDataloader, clear_cache_before_run
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd from tests.kit.model_zoo import model_zoo, run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
@pytest.mark.skip("this is not used") @pytest.mark.skip("this is not used")
...@@ -16,21 +15,22 @@ def test_runtime_mem_tracer(): ...@@ -16,21 +15,22 @@ def test_runtime_mem_tracer():
test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"] test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"]
for model_name in test_models: for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, _, _, criterion = get_components_func() iter(model_zoo.get_sub_registry(model_name).values())
)
model = model_builder(checkpoint=False).cuda() model = model_builder().cuda()
model_bk = deepcopy(model) model_bk = deepcopy(model)
runtime_mem_tracer = RuntimeMemTracer(model) runtime_mem_tracer = RuntimeMemTracer(model)
for i, (data, label) in enumerate(train_dataloader): train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 1: if i > 1:
break break
data = data.cuda() data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
label = label.cuda()
run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer) run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
for p1, p2 in zip(model_bk.parameters(), model.parameters()): for p1, p2 in zip(model_bk.parameters(), model.parameters()):
torch.allclose(p1.to(torch.half), p2) torch.allclose(p1.to(torch.half), p2)
......
...@@ -5,40 +5,37 @@ import colossalai ...@@ -5,40 +5,37 @@ import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
def exam_search_chunk_size(): def exam_search_chunk_size():
world_size = torch.distributed.get_world_size() model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
get_components_func = non_distributed_component_funcs.get_callable("gpt2") )
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
# make sure torch_model and model has the same parameter values # make sure torch_model and model has the same parameter values
model = model_builder() model = model_builder()
config_dict, *_ = search_chunk_configuration( config_dict, *_ = search_chunk_configuration(
model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True model, search_range_m=1, search_interval=128, min_chunk_size_m=0, filter_exlarge_params=True
) )
for key in config_dict: for key in config_dict:
chunk_size = config_dict[key]["chunk_size"] chunk_size = config_dict[key]["chunk_size"]
if world_size == 1 or True: assert chunk_size == 527872
assert chunk_size == 31616
else:
assert chunk_size == 1024
def exam_chunk_manager(): def exam_chunk_manager():
world_size = torch.distributed.get_world_size() world_size = torch.distributed.get_world_size()
get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
sharded_ddp_model = model_builder() sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager( chunk_manager = init_chunk_manager(
sharded_ddp_model, sharded_ddp_model,
get_current_device(), get_current_device(),
hidden_dim=16, hidden_dim=128,
search_range_m=1, search_range_m=1,
min_chunk_size_m=0, min_chunk_size_m=0,
filter_exlarge_params=True, filter_exlarge_params=True,
...@@ -46,7 +43,7 @@ def exam_chunk_manager(): ...@@ -46,7 +43,7 @@ def exam_chunk_manager():
) )
config_dict = chunk_manager.dp_degree_chunk_size_dict config_dict = chunk_manager.dp_degree_chunk_size_dict
assert len(config_dict) == 1 assert len(config_dict) == 1
assert config_dict[world_size] == 31616 assert config_dict[world_size] == 527872
def run_dist(rank, world_size, port): def run_dist(rank, world_size, port):
......
...@@ -7,7 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn ...@@ -7,7 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
...@@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module): ...@@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS) @parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"]) @parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
@parameterize("master_weights", [False, True]) @parameterize("master_weights", [False, True])
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool): def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name) model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder() model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
torch_model = model_builder() torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()): for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data) torch_p.data.copy_(p.data)
...@@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei ...@@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
# check load state dict
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
torch_dict = torch_model.state_dict()
model.load_state_dict(torch_dict, strict=False) model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False) zero_dict = model.state_dict(only_rank_0=False)
...@@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste ...@@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype) temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5) assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
# check state dict shard
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
accumulated_keys = set() accumulated_keys = set()
# ensure number of shards > 1 # ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False): for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
...@@ -116,8 +79,6 @@ def run_dist(rank, world_size, port): ...@@ -116,8 +79,6 @@ def run_dist(rank, world_size, port):
config = {} config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict() exam_state_dict()
exam_load_state_dict()
exam_state_dict_shard()
@pytest.mark.dist @pytest.mark.dist
......
...@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn ...@@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [ PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2 {"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
...@@ -22,8 +22,9 @@ PLACEMENT_CONFIGS = [ ...@@ -22,8 +22,9 @@ PLACEMENT_CONFIGS = [
@parameterize("keep_gathered", [True, False]) @parameterize("keep_gathered", [True, False])
def exam_zero_optim_state_dict(placement_config, keep_gathered): def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(431) set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable("gpt2") model_builder, data_gen_fn, output_transform_fn, *_ = next(
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func() iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
model = model_builder() model = model_builder()
...@@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered): ...@@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(dist.get_rank() * 3 + 128) set_seed(dist.get_rank() * 3 + 128)
model.train() model.train()
for i, (input_ids, label) in enumerate(train_dataloader): data = data_gen_fn()
if i > 0: data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
break
optim.zero_grad() optim.zero_grad()
logits = model(input_ids) outputs = model(**data)
logits = logits.float() outputs = output_transform_fn(outputs)
loss = criterion(logits, input_ids) loss = next(iter(outputs.values())).sum()
optim.backward(loss) optim.backward(loss)
optim.step() optim.step()
optim_state_dict = optim.state_dict() optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict) optim.load_state_dict(optim_state_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment