Unverified Commit b8e770c8 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
parent 3a41e830
......@@ -359,9 +359,9 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.BertConfig(
hidden_size=128,
......
......@@ -35,7 +35,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss function
loss_fn_blip2_model = lambda x: x.loss
loss_fn_blip2_model = lambda x: x["loss"]
config = transformers.Blip2Config()
config.vision_config.patch_size = 14
......
......@@ -69,11 +69,11 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
loss_fn_for_causal_lm = lambda x: x["loss"]
loss_fn_for_classification = lambda x: x["loss"]
loss_fn_for_question_answering = lambda x: x["loss"]
config = transformers.BloomConfig(
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256
......
......@@ -30,9 +30,9 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = ChatGLMConfig(
num_layers=2,
......
......@@ -87,13 +87,14 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.GPT2Config(
n_layer=2,
n_head=4,
n_embd=128,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,
......
......@@ -42,9 +42,9 @@ if HAS_LLAMA:
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
num_hidden_layers=4,
......
......@@ -45,9 +45,9 @@ def data_gen_for_question_answering():
output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_lm = lambda x: x.loss
loss_fn_for_lm = lambda x: x["loss"]
config = transformers.OPTConfig(
hidden_size=128,
num_hidden_layers=2,
......
......@@ -40,7 +40,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss function
loss_fn = lambda x: x.iou_scores.mean()
loss_fn = lambda x: x["iou_scores"].mean()
config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2
......
......@@ -44,9 +44,9 @@ def data_gen_for_t5_model():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
......
......@@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling():
output_transform_fn = lambda x: x
# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss
loss_fn_for_vit_model = lambda x: x["pooler_output"].mean()
loss_fn_for_image_classification = lambda x: x["logits"].mean()
loss_fn_for_masked_image_modeling = lambda x: x["loss"]
# register the following models
# transformers.ViTModel,
......
......@@ -53,8 +53,8 @@ def data_gen_for_audio_classification():
output_transform_fn = lambda x: x
# define loss function
loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn_attr = lambda x: x.loss
loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]))
loss_fn_attr = lambda x: x["loss"]
config = transformers.WhisperConfig(
classifier_proj_size=256,
......
......@@ -6,7 +6,7 @@ import torch
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def check_equal(a, b):
......@@ -25,13 +25,12 @@ def run_naive_amp():
torch.backends.cudnn.deterministic = True
# create layer
test_models = ["repeated_computed_layers", "nested_model", "resnet18"]
test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"]
for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name)
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
# create model
naive_amp_model = model_builder(checkpoint=True).cuda()
naive_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(naive_amp_model)
# create optimizer
......@@ -48,13 +47,12 @@ def run_naive_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data
data_iter = iter(train_dataloader)
data, label = next(data_iter)
data = data.cuda()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# forward pass
naive_amp_output = naive_amp_model(data)
apex_amp_output = apex_amp_model(data)
naive_amp_output = naive_amp_model(**data)
apex_amp_output = apex_amp_model(**data)
assert_close_loose(naive_amp_output, apex_amp_output)
# backward
......
......@@ -6,7 +6,7 @@ import torch
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def run_torch_amp():
......@@ -18,13 +18,12 @@ def run_torch_amp():
torch.backends.cudnn.deterministic = True
# create layer
test_models = ["resnet18", "simple_net"]
test_models = ["torchvision_resnet18", "custom_simple_net"]
for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name)
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
# create model
torch_amp_model = model_builder(checkpoint=True).cuda()
torch_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(torch_amp_model)
# create optimizer
......@@ -41,13 +40,12 @@ def run_torch_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data
data_iter = iter(train_dataloader)
data, label = next(data_iter)
data = data.cuda()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# forward pass
torch_amp_output = torch_amp_model(data)
apex_amp_output = apex_amp_model(data)
torch_amp_output = torch_amp_model(**data)
apex_amp_output = apex_amp_model(**data)
assert_close_loose(torch_amp_output, apex_amp_output)
for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):
......
import pytest
import torch
import colossalai
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.core import global_context as gpc
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
CONFIG = dict(
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
......@@ -15,29 +16,29 @@ CONFIG = dict(
@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
def run_train(model_name, amp_mode):
# FIXME: test bert
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
train_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
gpc.config.fp16["mode"] = amp_mode
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False)
model = model_builder()
engine, train_dataloader, *args = colossalai.legacy.initialize(
model=model,
optimizer=optimizer_class(model.parameters(), lr=1e-3),
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
criterion=criterion,
train_dataloader=train_dataloader,
)
try:
engine.train()
for data, label in train_dataloader:
for data in train_dataloader:
engine.zero_grad()
data = data.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
if criterion:
output = engine(data)
loss = engine.criterion(output, label)
output = engine(**data)
loss = engine.criterion(output)
else:
loss = engine(data, label)
loss = engine(**data)
engine.backward(loss)
engine.step()
break
......
......@@ -5,9 +5,9 @@ import colossalai
from colossalai.legacy.amp.amp_type import AMP_TYPE
from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
BATCH_SIZE = 4
IMG_SIZE = 32
......@@ -16,12 +16,14 @@ NUM_EPOCHS = 200
CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))
@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"])
@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"])
def run_trainer(model_name):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model = model_builder()
optimizer = optimizer_class(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_dataloader = DummyDataloader(data_gen_fn)
test_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
engine, train_dataloader, *_ = colossalai.legacy.initialize(
model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
)
......
......@@ -2,7 +2,7 @@ import torch
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def move_some_params_to_cuda(model, torch_model):
......@@ -22,8 +22,7 @@ def check_params_equal(model, torch_model):
@parameterize("nvme_offload_dir", ["./offload", None])
@parameterize("adam_cls", [CPUAdam, HybridAdam])
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
get_components_func = non_distributed_component_funcs.get_callable("simple_net")
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values()))
model = model_builder()
torch_model = model_builder()
move_some_params_to_cuda(model, torch_model)
......
......@@ -12,8 +12,7 @@ from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
......@@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
def exam_gpt_fwd_bwd(
......@@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd(
master_weights: bool = True,
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42)
model = model_builder(use_grad_checkpoint)
model = model_builder()
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
torch_model.gradient_checkpointing_enable()
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
......@@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd(
torch_model = DDP(torch_model, device_ids=[rank])
set_seed(rank)
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
torch_optim.zero_grad()
zero_optim.zero_grad()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
torch_optim.zero_grad()
zero_optim.zero_grad()
# set random seed is same as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
# set random seed is same as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
assert torch.equal(torch_loss, loss)
assert_close(torch_loss.float(), loss.float())
check_grad(model, torch_model)
check_grad(model, torch_model)
def run_dist(rank, world_size, port):
......
......@@ -3,38 +3,34 @@ import torch
import torch.distributed as dist
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
# run gemini use the runtime memory tracer
@parameterize("placement_policy", ["auto"])
@parameterize("keep_gather", [False])
@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"])
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_grad_checkpoint", [False, True])
def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model = model_builder(use_grad_checkpoint).cuda()
model = model_builder().cuda()
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
print(f"model_name {model_name}")
runtime_mem_tracer = RuntimeMemTracer(model)
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
# mem tracing
if i == 0:
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
runtime_mem_tracer = RuntimeMemTracer(model)
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data))
......@@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
)
set_seed(dist.get_rank())
for i, (input_ids, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
# print(f'iteration {i}')
if i > 4:
break
input_ids, label = input_ids.cuda(), label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42)
run_fwd_bwd(model, input_ids, label, criterion, model)
run_fwd_bwd(model, data, output_transform_fn, optimizer=model)
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda")
......
......@@ -7,13 +7,12 @@ from torch.testing import assert_close
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
......@@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
# Compare gradients.
for p0, p1 in zip(model.parameters(), torch_model.parameters()):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)
# Release gradient chunks and move them to gradient device.
for grad_chunk, device in zip(grad_chunk_list, device_list):
......@@ -48,21 +47,19 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
):
def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42)
gemini_model = model_builder(use_grad_checkpoint)
gemini_model = model_builder()
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data)
......@@ -94,22 +91,23 @@ def exam_gemini_grad_acc(
set_seed(rank)
accum_iter = 4
for i, (input_ids, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
delay_unscale = False if (i + 1) % accum_iter == 0 else True
input_ids, label = input_ids.cuda(), label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42 + rank)
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
torch_loss = torch_loss / accum_iter
with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward()
set_seed(42 + rank)
gemini_loss = run_fwd(gemini_model, input_ids, label, criterion)
gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
gemini_loss = gemini_loss / accum_iter
gemini_optim.backward(gemini_loss)
assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5)
assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)
check_grad(gemini_model, torch_model)
......
......@@ -7,12 +7,11 @@ from torch.testing import assert_close
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{
......@@ -51,12 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32)
......@@ -94,21 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
torch_model.train()
set_seed(dist.get_rank() * 3 + 128)
for i, (data, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2:
break
data = data.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss)
run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
import apex.amp as apex_amp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment