Unverified commit b8e770c8, authored by Hongxin Liu, committed by GitHub

[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test

parent 3a41e830
@@ -359,9 +359,9 @@ output_transform_fn = lambda x: x
 # define loss funciton
 loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
-    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+    x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
 )
-loss_fn = lambda x: x.loss
+loss_fn = lambda x: x["loss"]
 config = transformers.BertConfig(
     hidden_size=128,
...
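This x.attr -> x["key"] change repeats across all the model-zoo loss functions below. Both spellings are valid on Hugging Face ModelOutput objects, which support key access as well as attribute access. A minimal sketch of the equivalence (the config values here are illustrative, not taken from the test files):

    import torch
    import transformers

    config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
    model = transformers.BertModel(config)
    out = model(input_ids=torch.randint(0, config.vocab_size, (2, 8)))
    # ModelOutput implements __getitem__ for string keys, so both forms return the same tensor
    assert out["last_hidden_state"] is out.last_hidden_state

Key access also keeps working when a wrapper returns a plain dict, where attribute access would fail.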
@@ -35,7 +35,7 @@ def data_gen():
 output_transform_fn = lambda x: x
 # define loss funciton
-loss_fn_blip2_model = lambda x: x.loss
+loss_fn_blip2_model = lambda x: x["loss"]
 config = transformers.Blip2Config()
 config.vision_config.patch_size = 14
...
@@ -69,11 +69,11 @@ output_transform_fn = lambda x: x
 # define loss function
 loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
-    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+    x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
 )
-loss_fn_for_causal_lm = lambda x: x.loss
-loss_fn_for_classification = lambda x: x.loss
-loss_fn_for_question_answering = lambda x: x.loss
+loss_fn_for_causal_lm = lambda x: x["loss"]
+loss_fn_for_classification = lambda x: x["loss"]
+loss_fn_for_question_answering = lambda x: x["loss"]
 config = transformers.BloomConfig(
     n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
...
@@ -30,9 +30,9 @@ output_transform_fn = lambda x: x
 # define loss function
 loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
-    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+    x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
 )
-loss_fn = lambda x: x.loss
+loss_fn = lambda x: x["loss"]
 config = ChatGLMConfig(
     num_layers=2,
...
@@ -87,13 +87,14 @@ output_transform_fn = lambda x: x
 # define loss function
 loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
-    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+    x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
 )
-loss_fn = lambda x: x.loss
+loss_fn = lambda x: x["loss"]
 config = transformers.GPT2Config(
     n_layer=2,
     n_head=4,
+    n_embd=128,
     vocab_size=50258,
     attn_pdrop=0,
     embd_pdrop=0,
...
@@ -42,9 +42,9 @@ if HAS_LLAMA:
     output_transform_fn = lambda x: x
     # function to get the loss
-    loss_fn = lambda output: output.last_hidden_state.mean()
-    loss_fn_for_casual_lm = lambda output: output.loss
-    loss_fn_for_seq_classification = lambda output: output.logits.mean()
+    loss_fn = lambda output: output["last_hidden_state"].mean()
+    loss_fn_for_casual_lm = lambda output: output["loss"]
+    loss_fn_for_seq_classification = lambda output: output["logits"].mean()
     config = LlamaConfig(
         num_hidden_layers=4,
...
@@ -45,9 +45,9 @@ def data_gen_for_question_answering():
 output_transform_fn = lambda x: x
 loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
-    x.last_hidden_state, torch.ones_like(x.last_hidden_state)
+    x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
 )
-loss_fn_for_lm = lambda x: x.loss
+loss_fn_for_lm = lambda x: x["loss"]
 config = transformers.OPTConfig(
     hidden_size=128,
     num_hidden_layers=2,
...
@@ -40,7 +40,7 @@ def data_gen():
 output_transform_fn = lambda x: x
 # define loss funciton
-loss_fn = lambda x: x.iou_scores.mean()
+loss_fn = lambda x: x["iou_scores"].mean()
 config = transformers.SamConfig()
 config.vision_config.num_hidden_layers = 2
...
@@ -44,9 +44,9 @@ def data_gen_for_t5_model():
 output_transform_fn = lambda x: x
 # define loss function
-loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
-loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
-loss_fn_for_conditional_generation = lambda x: x.loss
+loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
+loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
+loss_fn_for_conditional_generation = lambda x: x["loss"]
 # define model config
 config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
...
@@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling():
 output_transform_fn = lambda x: x
 # function to get the loss
-loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
-loss_fn_for_image_classification = lambda x: x.logits.mean()
-loss_fn_for_masked_image_modeling = lambda x: x.loss
+loss_fn_for_vit_model = lambda x: x["pooler_output"].mean()
+loss_fn_for_image_classification = lambda x: x["logits"].mean()
+loss_fn_for_masked_image_modeling = lambda x: x["loss"]
 # register the following models
 # transformers.ViTModel,
...
@@ -53,8 +53,8 @@ def data_gen_for_audio_classification():
 output_transform_fn = lambda x: x
 # define loss funciton
-loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
-loss_fn_attr = lambda x: x.loss
+loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]))
+loss_fn_attr = lambda x: x["loss"]
 config = transformers.WhisperConfig(
     classifier_proj_size=256,
...
@@ -6,7 +6,7 @@ import torch
 import colossalai
 from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp
 from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo
 def check_equal(a, b):
@@ -25,13 +25,12 @@ def run_naive_amp():
     torch.backends.cudnn.deterministic = True
     # create layer
-    test_models = ["repeated_computed_layers", "nested_model", "resnet18"]
+    test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"]
     for test_name in test_models:
-        get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_class, _ = get_component_func()
+        model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
         # create model
-        naive_amp_model = model_builder(checkpoint=True).cuda()
+        naive_amp_model = model_builder().cuda()
         apex_amp_model = copy.deepcopy(naive_amp_model)
         # create optimizer
@@ -48,13 +47,12 @@ def run_naive_amp():
         apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
         # create data
-        data_iter = iter(train_dataloader)
-        data, label = next(data_iter)
-        data = data.cuda()
+        data = data_gen_fn()
+        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
         # forward pass
-        naive_amp_output = naive_amp_model(data)
-        apex_amp_output = apex_amp_model(data)
+        naive_amp_output = naive_amp_model(**data)
+        apex_amp_output = apex_amp_model(**data)
         assert_close_loose(naive_amp_output, apex_amp_output)
         # backward
...
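The fetch idiom introduced here recurs in every migrated test: pick the first entry from a name-filtered sub-registry and unpack its leading fields. A hypothetical sketch of what the call site assumes about the registry API (the real tests.kit.model_zoo may expose more fields per entry):

    from tests.kit.model_zoo import model_zoo

    sub_registry = model_zoo.get_sub_registry("torchvision_resnet18")   # filter entries by name
    model_builder, data_gen_fn, *_ = next(iter(sub_registry.values()))  # first matching entry

    model = model_builder()   # builders now take no arguments (no checkpoint flag)
    data = data_gen_fn()      # a dict of keyword arguments for the forward pass
    output = model(**data)    # hence model(**data) rather than model(data)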
@@ -6,7 +6,7 @@ import torch
 import colossalai
 from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp
 from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo
 def run_torch_amp():
@@ -18,13 +18,12 @@ def run_torch_amp():
     torch.backends.cudnn.deterministic = True
     # create layer
-    test_models = ["resnet18", "simple_net"]
+    test_models = ["torchvision_resnet18", "custom_simple_net"]
     for test_name in test_models:
-        get_component_func = non_distributed_component_funcs.get_callable(test_name)
-        model_builder, train_dataloader, _, optim_class, _ = get_component_func()
+        model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
         # create model
-        torch_amp_model = model_builder(checkpoint=True).cuda()
+        torch_amp_model = model_builder().cuda()
         apex_amp_model = copy.deepcopy(torch_amp_model)
         # create optimizer
@@ -41,13 +40,12 @@ def run_torch_amp():
         apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
         # create data
-        data_iter = iter(train_dataloader)
-        data, label = next(data_iter)
-        data = data.cuda()
+        data = data_gen_fn()
+        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
         # forward pass
-        torch_amp_output = torch_amp_model(data)
-        apex_amp_output = apex_amp_model(data)
+        torch_amp_output = torch_amp_model(**data)
+        apex_amp_output = apex_amp_model(**data)
         assert_close_loose(torch_amp_output, apex_amp_output)
         for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):
...
 import pytest
+import torch
 import colossalai
 from colossalai.legacy.amp import AMP_TYPE
 from colossalai.legacy.core import global_context as gpc
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from tests.components_to_test.registry import non_distributed_component_funcs
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
 CONFIG = dict(
     parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
@@ -15,29 +16,29 @@ CONFIG = dict(
 @parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
 def run_train(model_name, amp_mode):
     # FIXME: test bert
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
+    model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    train_dataloader = DummyDataloader(data_gen_fn)
+    criterion = lambda x: x.sum()
     gpc.config.fp16["mode"] = amp_mode
-    model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
-    model = model_builder(checkpoint=False)
+    model = model_builder()
     engine, train_dataloader, *args = colossalai.legacy.initialize(
         model=model,
-        optimizer=optimizer_class(model.parameters(), lr=1e-3),
+        optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
         criterion=criterion,
         train_dataloader=train_dataloader,
     )
     try:
         engine.train()
-        for data, label in train_dataloader:
+        for data in train_dataloader:
             engine.zero_grad()
-            data = data.cuda()
-            label = label.cuda()
+            data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
             if criterion:
-                output = engine(data)
-                loss = engine.criterion(output, label)
+                output = engine(**data)
+                loss = engine.criterion(output)
             else:
-                loss = engine(data, label)
+                loss = engine(**data)
             engine.backward(loss)
             engine.step()
             break
...
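DummyDataloader replaces the per-component dataloaders the old registry provided. Its implementation is not shown in this diff; a minimal sketch consistent with how the tests use it (an iterable that keeps yielding fresh data_gen_fn() batches) would be:

    # Assumed behavior only -- the real colossalai.testing.DummyDataloader may differ.
    class DummyDataloader:
        def __init__(self, data_gen_fn, length=10):
            self.data_gen_fn = data_gen_fn
            self.length = length

        def __iter__(self):
            for _ in range(self.length):
                yield self.data_gen_fn()  # each batch is a dict of model kwargs

        def __len__(self):
            return self.length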
@@ -5,9 +5,9 @@ import colossalai
 from colossalai.legacy.amp.amp_type import AMP_TYPE
 from colossalai.legacy.trainer import Trainer
 from colossalai.logging import get_dist_logger
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import MultiTimer
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo
 BATCH_SIZE = 4
 IMG_SIZE = 32
@@ -16,12 +16,14 @@ NUM_EPOCHS = 200
 CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))
-@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"])
+@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"])
 def run_trainer(model_name):
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
     model = model_builder()
-    optimizer = optimizer_class(model.parameters(), lr=1e-3)
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+    train_dataloader = DummyDataloader(data_gen_fn)
+    test_dataloader = DummyDataloader(data_gen_fn)
+    criterion = lambda x: x.sum()
     engine, train_dataloader, *_ = colossalai.legacy.initialize(
         model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
     )
...
@@ -2,7 +2,7 @@ import torch
 from colossalai.nn.optimizer import CPUAdam, HybridAdam
 from colossalai.testing import clear_cache_before_run, parameterize
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo
 def move_some_params_to_cuda(model, torch_model):
@@ -22,8 +22,7 @@ def check_params_equal(model, torch_model):
 @parameterize("nvme_offload_dir", ["./offload", None])
 @parameterize("adam_cls", [CPUAdam, HybridAdam])
 def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
-    get_components_func = non_distributed_component_funcs.get_callable("simple_net")
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values()))
     model = model_builder()
     torch_model = model_builder()
     move_some_params_to_cuda(model, torch_model)
...
@@ -12,8 +12,7 @@ from colossalai.utils import set_seed
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
-from tests.components_to_test import run_fwd_bwd
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo, run_fwd_bwd
 PLACEMENT_CONFIGS = [
     {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
@@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gather", [False, True])
-@parameterize("model_name", ["gpt2", "bert"])
+@parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("use_grad_checkpoint", [False, True])
 @parameterize("master_weights", [False, True])
 def exam_gpt_fwd_bwd(
@@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd(
     master_weights: bool = True,
 ):
     init_device = get_current_device()
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
+        iter(model_zoo.get_sub_registry(model_name).values())
+    )
     set_seed(42)
-    model = model_builder(use_grad_checkpoint)
+    model = model_builder()
     set_seed(42)
-    torch_model = model_builder(use_grad_checkpoint).cuda()
+    torch_model = model_builder().cuda()
     for torch_p, p in zip(torch_model.parameters(), model.parameters()):
         torch_p.data.copy_(p.data)
+    if use_grad_checkpoint:
+        model.gradient_checkpointing_enable()
+        torch_model.gradient_checkpointing_enable()
     world_size = torch.distributed.get_world_size()
     config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
     config_dict[world_size]["chunk_size"] = 5000
@@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd(
     torch_model = DDP(torch_model, device_ids=[rank])
     set_seed(rank)
-    for i, (input_ids, label) in enumerate(train_dataloader):
-        # you can only test a single fwd + bwd.
-        # after bwd param is grad for Gemini, due to the chunk reuse optimization.
-        if i > 0:
-            break
-        input_ids, label = input_ids.cuda(), label.cuda()
-        torch_optim.zero_grad()
-        zero_optim.zero_grad()
+    data = data_gen_fn()
+    data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
+    torch_optim.zero_grad()
+    zero_optim.zero_grad()
     # set random seed is same as torch_model.eval()
     set_seed(42)
-    torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
+    torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
     set_seed(42)
-    loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
-    assert torch.equal(torch_loss, loss)
+    loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
+    assert_close(torch_loss.float(), loss.float())
     check_grad(model, torch_model)
 def run_dist(rank, world_size, port):
...
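run_fwd_bwd now takes the whole data dict plus the registry's output_transform_fn and loss_fn instead of the old (input_ids, label, criterion) triple. A hypothetical reconstruction of the helper, inferred purely from its call sites in this diff (the actual tests.kit.model_zoo implementation may differ):

    # Sketch inferred from usage; loss_fn is optional (the RuntimeMemTracer test omits it).
    def run_fwd_bwd(model, data, output_transform_fn, loss_fn=None, optimizer=None):
        output = model(**data)                 # data is a dict of keyword arguments
        output = output_transform_fn(output)   # normalize the model output
        loss = loss_fn(output) if loss_fn is not None else output["loss"]
        if optimizer is not None and hasattr(optimizer, "backward"):
            optimizer.backward(loss)           # Gemini optimizers / RuntimeMemTracer own backward()
        else:
            loss.backward()
        return loss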
@@ -3,38 +3,34 @@ import torch
 import torch.distributed as dist
 import colossalai
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP
 from colossalai.zero.gemini.chunk import search_chunk_configuration
 from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
-from tests.components_to_test import run_fwd_bwd
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo, run_fwd_bwd
 # run gemini use the runtime memory tracer
 @parameterize("placement_policy", ["auto"])
 @parameterize("keep_gather", [False])
-@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"])
+@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_grad_checkpoint", [False, True])
 def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
     set_seed(42)
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
-    model = model_builder(use_grad_checkpoint).cuda()
+    model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    model = model_builder().cuda()
+    if use_grad_checkpoint:
+        model.gradient_checkpointing_enable()
     print(f"model_name {model_name}")
-    runtime_mem_tracer = RuntimeMemTracer(model)
-    for i, (input_ids, label) in enumerate(train_dataloader):
-        if i > 0:
-            break
-        input_ids, label = input_ids.cuda(), label.cuda()
-        # mem tracing
-        if i == 0:
-            run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
+    runtime_mem_tracer = RuntimeMemTracer(model)
+    data = data_gen_fn()
+    data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
+    run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
     memstats = runtime_mem_tracer.memstats()
     runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
     print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data))
@@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
     )
     set_seed(dist.get_rank())
-    for i, (input_ids, label) in enumerate(train_dataloader):
+    train_dataloader = DummyDataloader(data_gen_fn)
+    for i, data in enumerate(train_dataloader):
         # you can only test a single fwd + bwd.
         # after bwd param is grad for Gemini, due to the chunk reuse optimization.
         # print(f'iteration {i}')
         if i > 4:
             break
-        input_ids, label = input_ids.cuda(), label.cuda()
+        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
         set_seed(42)
-        run_fwd_bwd(model, input_ids, label, criterion, model)
+        run_fwd_bwd(model, data, output_transform_fn, optimizer=model)
         gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda")
...
@@ -7,13 +7,12 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
 from colossalai.utils.cuda import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
-from tests.components_to_test import run_fwd
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo, run_fwd
 PLACEMENT_CONFIGS = [
     {"placement_policy": "static", "shard_param_frac": 0.0},  # zero2
@@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
     # Compare gradients.
     for p0, p1 in zip(model.parameters(), torch_model.parameters()):
-        assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
+        assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)
     # Release gradient chunks and move them to gradient device.
     for grad_chunk, device in zip(grad_chunk_list, device_list):
@@ -48,21 +47,19 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
 @parameterize("keep_gathered", [False, True])
-@parameterize("model_name", ["gpt2", "bert"])
-@parameterize("use_grad_checkpoint", [False, True])
+@parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [False, True])
-def exam_gemini_grad_acc(
-    placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
-):
+def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
     init_device = get_current_device()
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, _, _, criterion = get_components_func()
+    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
+        iter(model_zoo.get_sub_registry(model_name).values())
+    )
     set_seed(42)
-    gemini_model = model_builder(use_grad_checkpoint)
+    gemini_model = model_builder()
     set_seed(42)
-    torch_model = model_builder(use_grad_checkpoint).cuda()
+    torch_model = model_builder().cuda()
     for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
         torch_p.data.copy_(p.data)
@@ -94,22 +91,23 @@ def exam_gemini_grad_acc(
     set_seed(rank)
     accum_iter = 4
-    for i, (input_ids, label) in enumerate(train_dataloader):
+    train_dataloader = DummyDataloader(data_gen_fn)
+    for i, data in enumerate(train_dataloader):
         delay_unscale = False if (i + 1) % accum_iter == 0 else True
-        input_ids, label = input_ids.cuda(), label.cuda()
+        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
         set_seed(42 + rank)
-        torch_loss = run_fwd(torch_model, input_ids, label, criterion)
+        torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
         torch_loss = torch_loss / accum_iter
         with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
             scaled_loss.backward()
         set_seed(42 + rank)
-        gemini_loss = run_fwd(gemini_model, input_ids, label, criterion)
+        gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
         gemini_loss = gemini_loss / accum_iter
         gemini_optim.backward(gemini_loss)
-        assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5)
+        assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)
         check_grad(gemini_model, torch_model)
...
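The accumulation loop above divides each micro-batch loss by accum_iter before backward, so that the gradients summed over accum_iter backward calls equal one full-batch mean-loss gradient. A self-contained toy illustration of that arithmetic, separate from the Gemini code path:

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    accum_iter = 4
    for i in range(8):
        x = torch.randn(2, 4)
        loss = model(x).pow(2).mean() / accum_iter  # scale so accumulated grads match a full-batch mean
        loss.backward()                             # .grad buffers accumulate across iterations
        if (i + 1) % accum_iter == 0:
            optimizer.step()                        # apply the accumulated gradient
            optimizer.zero_grad()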
@@ -7,12 +7,11 @@ from torch.testing import assert_close
 import colossalai
 from colossalai.legacy.amp import convert_to_apex_amp
 from colossalai.nn.optimizer import HybridAdam
-from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
+from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
 from colossalai.utils import set_seed
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.chunk import search_chunk_configuration
-from tests.components_to_test import run_fwd_bwd
-from tests.components_to_test.registry import non_distributed_component_funcs
+from tests.kit.model_zoo import model_zoo, run_fwd_bwd
 PLACEMENT_CONFIGS = [
     {
@@ -51,12 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
 @parameterize("placement_config", PLACEMENT_CONFIGS)
-@parameterize("model_name", ["gpt2"])
+@parameterize("model_name", ["transformers_gpt_lm"])
 @parameterize("master_weights", [True, False])
 def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     set_seed(1912)
-    get_components_func = non_distributed_component_funcs.get_callable(model_name)
-    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
+        iter(model_zoo.get_sub_registry(model_name).values())
+    )
     torch_model = model_builder().cuda()
     amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32)
@@ -94,21 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
     torch_model.train()
     set_seed(dist.get_rank() * 3 + 128)
-    for i, (data, label) in enumerate(train_dataloader):
+    train_dataloader = DummyDataloader(data_gen_fn)
+    for i, data in enumerate(train_dataloader):
         if i > 2:
             break
-        data = data.cuda()
-        label = label.cuda()
+        data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
         zero_optim.zero_grad()
         torch_optim.zero_grad()
-        torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
-        loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
-        # as no master weights leads to error accumulation, we don't check the loss
-        if master_weights:
-            assert_close(torch_loss, loss)
+        run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
+        run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
     import apex.amp as apex_amp
...