"...coati/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7bd0bee8ea0ca7d713c34792dc99c1970a2c6701"
Unverified commit d66e6988, authored by Frank Lee, committed by GitHub

Merge pull request #5278 from ver217/sync/npu

[sync] sync npu branch with main
parents 9102d655 14846934
@@ -4,6 +4,7 @@ from types import MethodType
 import pytest
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 
 import colossalai
@@ -14,21 +15,26 @@ from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.testing import rerun_if_address_is_in_use, spawn
 from colossalai.testing.random import seed_all
 
+DIM = 8
+NUM_LAYER = 8
+
 
 class MlpModel(nn.Module):
     def __init__(self):
-        super(MlpModel, self).__init__()
-        self.linear1 = nn.Linear(4, 8)
-        self.linear2 = nn.Linear(8, 4)
+        super().__init__()
+        self.layers = nn.ModuleList([nn.Linear(DIM, DIM) for _ in range(NUM_LAYER)])
 
     def forward(self, x):
-        x = self.linear1(x)
-        x = self.linear2(x)
+        for layer in self.layers:
+            x = layer(x)
         return x
 
 
 def pp_linear_fwd(
-    forward, data: torch.Tensor = None, input_obj: torch.Tensor = None, stage_mgr: PipelineStageManager = None
+    forward,
+    data: torch.Tensor = None,
+    input_obj: torch.Tensor = None,
+    stage_mgr: PipelineStageManager = None,
 ):
     if stage_mgr.is_first_stage():
         return {"input_obj": forward(data)}
@@ -38,34 +44,45 @@ def pp_linear_fwd(
     return {"input_obj": forward(input_obj)}
 
 
-def examine_pp():
+def examine_pp(num_microbatch: int, batch_size: int):
     """
     This test is to examine the correctness of 1F1B, compared with torch.
     Be aware it contains some hardcodes.
     """
-    world_size = torch.distributed.get_world_size()
-    local_rank = torch.distributed.get_rank()
+    world_size = dist.get_world_size()
+    dist.get_rank()
     seed_all(1453)
-    NUM_MICRO_BATCHS = 4
-    BATCH_SIZE = 4
 
     # create models
     torch_model = MlpModel().cuda()
     pp_model = copy.deepcopy(torch_model).cuda()
 
-    DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
-    pg_mesh = ProcessGroupMesh(1, world_size, 1)
-    stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
-    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
+    pg_mesh = ProcessGroupMesh(world_size)
+    stage_manager = PipelineStageManager(pg_mesh, pipeline_axis=0)
+    schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=num_microbatch)
 
-    for idx, (_, sub_model) in enumerate(pp_model.named_children()):
-        if idx % (world_size) == local_rank:
-            sharded_model = sub_model.cuda()
+    rank = dist.get_rank()
+    sharded_model = torch.nn.ModuleList()
+    num_local_layer = NUM_LAYER // world_size
+    for idx, sub_model in enumerate(pp_model.layers):
+        if idx // num_local_layer == rank:
+            sharded_model.append(sub_model.cuda())
+    assert len(sharded_model) == num_local_layer
 
-    sharded_model._forward = sharded_model.forward
-    sharded_model.forward = MethodType(partial(pp_linear_fwd, stage_mgr=stage_manager), sharded_model._forward)
+    def custom_fwd(self, x):
+        for layer in self._modules.values():
+            x = layer(x)
+        return x
+
+    sharded_model._forward = MethodType(custom_fwd, sharded_model)
+    sharded_model.forward = MethodType(
+        partial(
+            pp_linear_fwd,
+            stage_mgr=stage_manager,
+        ),
+        sharded_model._forward,
+    )
 
     # create optimizer
     torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
@@ -73,19 +90,15 @@ def examine_pp():
     # create
     seed_all(1453)
-    if stage_manager.is_first_stage():
-        input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
-    else:
-        input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
-    torch.distributed.all_reduce(input_list[0])
+    input_list = [torch.rand(batch_size, DIM).cuda()]
+    dist.all_reduce(input_list[0])
 
-    criterion = lambda x, y: torch.mean(x)
+    criterion = lambda x, *arg, **kwargs: (x * x).mean()
 
     # forward and backward
     torch_output = torch_model(input_list[0])
-    torch_loss = criterion(torch_output, _)
+    torch_loss = criterion(torch_output)
     torch_loss.backward()
     pp_ret = schedule.forward_backward_step(
         sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
     )
@@ -95,34 +108,66 @@ def examine_pp():
     assert torch.allclose(torch_loss, pp_ret["loss"])
 
     # check gradients
-    torch_grad = []
-    for torch_p in torch_model.parameters():
-        torch_grad.append(torch_p.grad.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight.grad, sharded_model[i].weight.grad)
+        assert torch.allclose(torch_model.layers[idx].bias.grad, sharded_model[i].bias.grad)
 
     # step
     torch_optimizer.step()
     pp_optimizer.step()
+    pp_optimizer.zero_grad()
 
     # check updated param
-    torch_param = []
-    for torch_p in torch_model.parameters():
-        torch_param.append(torch_p.data)
-    for idx, pp_p in enumerate(sharded_model.parameters()):
-        assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
+    for i in range(len(sharded_model)):
+        idx = rank * num_local_layer + i
+        assert torch.allclose(torch_model.layers[idx].weight, sharded_model[i].weight)
+        assert torch.allclose(torch_model.layers[idx].bias, sharded_model[i].bias)
+
+    # forward only
+    with torch.no_grad():
+        torch_output = torch_model(input_list[0])
+        torch_loss = criterion(torch_output)
+
+        pp_ret = schedule.forward_backward_step(
+            sharded_model, iter(input_list), criterion, pp_optimizer, return_loss=True, return_outputs=True
+        )
+        if stage_manager.is_last_stage():
+            assert torch.allclose(torch_loss, pp_ret["loss"])
+
+        for layer in sharded_model:
+            if layer.weight.grad is None:
+                assert layer.weight.grad is None and layer.bias.grad is None
+            else:
+                assert torch.allclose(layer.weight.grad, torch.zeros_like(layer.weight.grad))
+                assert torch.allclose(layer.bias.grad, torch.zeros_like(layer.bias.grad))
-def run_dist(rank, world_size, port):
+def run_dist(
+    rank: int,
+    world_size: int,
+    port: int,
+    num_microbatch: int,
+    batch_size: int,
+):
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host="localhost")
-    examine_pp()
+    examine_pp(num_microbatch, batch_size)
 
 
 @pytest.mark.dist
+@pytest.mark.parametrize("num_microbatch", [4, 6])
+@pytest.mark.parametrize("batch_size", [12])
+@pytest.mark.parametrize("world_size", [2, 4])
 @rerun_if_address_is_in_use()
-def test_pp():
-    spawn(run_dist, 2)
+def test_pp(num_microbatch: int, batch_size: int, world_size: int):
+    assert NUM_LAYER % world_size == 0
+    spawn(
+        run_dist,
+        world_size,
+        num_microbatch=num_microbatch,
+        batch_size=batch_size,
+    )
 
 
 if __name__ == "__main__":
-    test_pp()
+    test_pp(num_microbatch=4, batch_size=4, world_size=4)
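Note on the change above: the rewritten test assigns each pipeline rank a contiguous block of NUM_LAYER // world_size layers instead of a single strided layer. A minimal standalone sketch of that mapping (illustration only, not part of the commit):

    # Illustration of the partitioning used in examine_pp:
    # rank r owns layers [r * num_local_layer, (r + 1) * num_local_layer).
    NUM_LAYER = 8
    world_size = 2
    num_local_layer = NUM_LAYER // world_size  # 4
    for rank in range(world_size):
        owned = [idx for idx in range(NUM_LAYER) if idx // num_local_layer == rank]
        print(rank, owned)  # 0 -> [0, 1, 2, 3], 1 -> [4, 5, 6, 7]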
@@ -17,23 +17,32 @@ def check_dist_crossentropy(rank, world_size, port, ignore_index):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
 
     # prepare data
-    pred = torch.randn(2, 4, 8, requires_grad=True)
-    labels = torch.randint(8, (2, 4))
+    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+    labels = torch.randint(8, (2, 4)).cuda()
     # set some label to -100 to test the ignore index
     labels[0, -1] = ignore_index
 
     org_pred = pred.view(-1, 8)
     org_labels = labels.view(-1)
     org_loss = F.cross_entropy(org_pred, org_labels)
+    pred.retain_grad()
+    org_loss.backward()
 
-    dist_pred = pred.chunk(world_size, -1)[rank]
-    dist_loss = cross_entropy_1d(dist_pred.to("cuda"), labels.to("cuda"), ignore_index=ignore_index)
+    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+    dist_pred.requires_grad = True
+    dist_loss = cross_entropy_1d(dist_pred, labels, ignore_index=ignore_index)
+    dist_pred.retain_grad()
+    dist_loss.backward()
 
     assert torch.allclose(
         org_loss, dist_loss, atol=1e-5
     ), f"dist cross entropy loss is not equal to orgin loss\n{org_loss}\n{dist_loss}"
+    target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
+    assert torch.allclose(target_grad, dist_pred.grad), f"dist grad is not equal to orgin grad\n{target_grad}\n{dist_pred.grad}"
 
 
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_dist_crossentropy():
...
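The new assertions check gradients as well as the loss value: the reference gradient for each rank is the vocab-dimension slice of the full pred.grad matching that rank's shard. A small sketch of just that reference computation (plain PyTorch, no colossalai; rank and world_size are stand-ins for the distributed context):

    import torch
    import torch.nn.functional as F

    world_size, rank = 2, 0
    pred = torch.randn(2, 4, 8, requires_grad=True)
    labels = torch.randint(8, (2, 4))
    loss = F.cross_entropy(pred.view(-1, 8), labels.view(-1))
    loss.backward()
    # The shard of the full gradient this rank compares its dist_pred.grad against:
    target_grad = torch.chunk(pred.grad, world_size, dim=-1)[rank]
    print(target_grad.shape)  # torch.Size([2, 4, 4])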
@@ -154,7 +154,7 @@ def run_forward_backward_with_hybrid_plugin(
     data = data_gen_fn()
 
-    if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
+    if booster.plugin.shard_config.enable_sequence_parallelism and booster.plugin.tp_size != 0:
         seq_len = data["input_ids"].shape[-1]
         lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
         times = lcm // seq_len
@@ -203,7 +203,7 @@ def check_output_hidden_state(
 ):
     org_hidden_state = org_output.last_hidden_state
 
-    if stage_manager and stage_manager.is_last_stage():
+    if stage_manager and stage_manager.is_last_stage(ignore_chunk=True):
         sharded_hidden_state = sharded_output["outputs"]["last_hidden_state"]
     else:
         sharded_hidden_state = sharded_output.last_hidden_state
@@ -229,6 +229,10 @@ def check_weight(
         org_weight = getattr_(org_model, suffix).weight
         sharded_weight = getattr_(sharded_model, suffix).weight
 
+        # skip if layer is not held by this process
+        if sharded_weight is None:
+            continue
+
         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
             sharded_weight_list = [
                 torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
...
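The added guard lets check_weight skip parameters that a pipeline stage does not hold (their local weight is None under interleaved pipelining). For tensor-parallel weights, the surrounding code gathers every shard before comparing against the original model; a generic sketch of that gather-and-concatenate pattern with plain torch.distributed, not colossalai's internal helper (group and dim are stand-ins):

    import torch
    import torch.distributed as dist

    def gather_sharded_weight(sharded_weight: torch.Tensor, group, dim: int = 1) -> torch.Tensor:
        # Collect one shard from every rank in the tensor-parallel group, then
        # concatenate along the sharded dimension to rebuild the full weight.
        shards = [torch.zeros_like(sharded_weight) for _ in range(dist.get_world_size(group))]
        dist.all_gather(shards, sharded_weight, group=group)
        return torch.cat(shards, dim=dim)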
@@ -37,6 +37,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     norm_layer_for_check = ["encoder.layer[0].attention.output.LayerNorm", "embeddings.LayerNorm"]
     col_layer_for_check = ["encoder.layer[0].output.dense"]
     row_layer_for_check = ["embeddings.word_embeddings", "encoder.layer[0].intermediate.dense"]
+    weight_layer_for_check = ["encoder.layer[0].output.dense", "encoder.layer[1].output.dense"]
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
@@ -44,7 +45,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 1e-4, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
-    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
         col_layer_grads = get_grad_tensors_for_check(
             bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
         )
@@ -72,7 +73,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     sharded_optimizer.step()
 
     # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
+    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
         else:
@@ -87,8 +88,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         atol, rtol = 5e-3, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
-    if stage_manager is None or stage_manager.is_first_stage():
-        check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
+    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+        check_weight(bert, sharded_bert, weight_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1)
 
     # check grads
     check_all_grad_tensors(grads_to_check)
@@ -183,6 +184,17 @@ def run_bert_test(test_config):
             "zero_stage": 1,
             "initial_scale": 1,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "interleaved",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
     ],
 )
 def run_bert_3d_test(test_config):
...
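The config added above exercises the interleaved (virtual-pipeline) schedule with two model chunks per pipeline rank, which is why the stage checks now pass ignore_chunk=True. A sketch of how such a config dict maps onto the plugin, assuming the test harness (build_model_from_hybrid_plugin) forwards these keys to HybridParallelPlugin as keyword arguments:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    # Sketch only: values mirror the new test config; exact plugin defaults may differ.
    plugin = HybridParallelPlugin(
        tp_size=2,
        pp_size=2,
        pp_style="interleaved",  # use the interleaved 1F1B schedule
        num_model_chunks=2,      # two virtual stages per pipeline rank
        num_microbatches=4,
        precision="fp16",
        zero_stage=1,
        initial_scale=1,
    )
    booster = Booster(plugin=plugin)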
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
falcon = unwrap_model(org_model, "FalconModel", "transformer")
sharded_falcon = unwrap_model(sharded_model, "FalconModel", "transformer")
row_layer_for_check = ["h[0].self_attention.query_key_value", "word_embeddings"]
col_layer_for_check = ["h[0].self_attention.dense"]
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
falcon, sharded_falcon, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
col_layer_grads = get_grad_tensors_for_check(
falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "FalconModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 2e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(falcon, sharded_falcon, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
"use_lazy_init": True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
},
{"tp_size": 4, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_falcon_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
def run_falcon_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_falcon")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_falcon(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_test()
def check_falcon_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_falcon_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_falcon():
spawn(check_falcon, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_falcon_3d():
spawn(check_falcon_3d, 8)
if __name__ == "__main__":
test_falcon()
test_falcon_3d()
@@ -165,7 +165,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 )
 @clear_cache_before_run()
 def run_gpt2_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
@@ -200,7 +200,7 @@ def run_gpt2_test(test_config):
 )
 @clear_cache_before_run()
 def run_gpt2_3d_test(test_config):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
...
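Both GPT-2 runners now exclude the GPT-J entries from the "transformers_gpt" registry, since GPT-J gets its own dedicated test file later in this diff. A small sketch of the filtering, assuming the returned registry is a name-keyed mapping as the .items() loops above suggest:

    from tests.kit.model_zoo import model_zoo

    # "exclude" mirrors the call added in this diff; GPT-J models are left to test_gptj.
    gpt_only = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
    for name in gpt_only:
        assert "gptj" not in name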
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
gptj = unwrap_model(org_model, "GPTJModel", "transformer")
sharded_gptj = unwrap_model(sharded_model, "GPTJModel", "transformer")
col_layer_for_check = ["h[0].attn.k_proj"]
row_layer_for_check = ["h[0].mlp.fc_out"] # use dim=0 for wte get_grad_tensors_for_check
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
)
row_layer_grads = get_grad_tensors_for_check(
gptj, sharded_gptj, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "GPTJModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 5e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(gptj, sharded_gptj, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
#'use_lazy_init': True,  # GPTJ does not currently support lazy init; model training has issues even without sharding
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
#'use_lazy_init': True,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": True,
#'use_lazy_init': True,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
#'use_lazy_init': True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
{
"tp_size": 1,
"pp_size": 2,
"num_microbatches": 2,
"enable_all_optimization": True,
#'use_lazy_init': True,
"zero_stage": 1,
"precision": "fp16",
"initial_scale": 1,
},
],
)
@clear_cache_before_run()
def run_gptj_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp32",
"initial_scale": 1,
},
{
"tp_size": 2,
"pp_size": 2,
"num_microbatches": 4,
"enable_all_optimization": False,
"use_lazy_init": False,
"precision": "fp16",
"zero_stage": 1,
"initial_scale": 1,
},
],
)
@clear_cache_before_run()
def run_gptj_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_gptj")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_gptj(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_gptj_test()
def check_gptj_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_gptj_3d_test()
@pytest.mark.skip("TODO check_gptj has something wrong.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptj():
spawn(check_gptj, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gptj_3d():
spawn(check_gptj_3d, 8)
if __name__ == "__main__":
test_gptj()
test_gptj_3d()
@@ -44,7 +44,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
-    if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
+    if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-6, 1e-4
         else:
@@ -63,7 +63,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     sharded_optimizer.step()
 
     # check last hidden state & loss
-    if stage_manager is None or stage_manager.is_last_stage():
+    if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-5, 1e-3
         else:
@@ -75,7 +75,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
 
     # check weights
-    if stage_manager is None or stage_manager.is_first_stage():
+    if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
         if test_config["precision"] == "fp32":
             atol, rtol = 1e-4, 1e-3
         else:
@@ -179,6 +179,17 @@ def run_llama_test(test_config):
             "zero_stage": 1,
             "initial_scale": 1,
         },
+        {
+            "tp_size": 2,
+            "pp_size": 2,
+            "pp_style": "interleaved",
+            "num_model_chunks": 2,
+            "num_microbatches": 4,
+            "enable_all_optimization": False,
+            "precision": "fp16",
+            "zero_stage": 1,
+            "initial_scale": 1,
+        },
     ],
 )
 def run_llama_3d_test(test_config):
...
import os
import pytest
import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
model_fn, loss_fn, test_config
)
org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
mistral_model = unwrap_model(org_model, "MistralModel", "model")
shard_mistral_model = unwrap_model(sharded_model, "MistralModel", "model")
row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
col_layer_for_check = ["layers[0].self_attn.o_proj"]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config["precision"] == "fp32":
atol, rtol = 5e-5, 1e-4
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(
mistral_model,
shard_mistral_model,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False,
)
col_layer_grads = get_grad_tensors_for_check(
mistral_model,
shard_mistral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == "MistralModel":
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config["precision"] == "fp32":
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_weight(
mistral_model,
shard_mistral_model,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False,
)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize(
"test_config",
[
{
"tp_size": 4,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": False,
"precision": "fp32",
},
{
"tp_size": 2,
"pp_size": 1,
"enable_all_optimization": True,
"use_lazy_init": True,
"zero_stage": 2,
"precision": "fp16",
"initial_scale": 1,
},
],
)
def run_mistral_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry("transformers_mistral")
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_mistral(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
run_mistral_test()
@pytest.mark.skip("This test should be run on a version of transformers not less than 4.35.2.")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_mistral():
spawn(check_mistral, 4)
if __name__ == "__main__":
test_mistral()
@@ -86,6 +86,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp16",
@@ -95,6 +96,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp16",
             "initial_scale": 1,
@@ -110,6 +112,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -128,6 +131,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "zero_stage": 1,
@@ -159,6 +163,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -168,6 +173,7 @@ def run_t5_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp16",
...
@@ -114,6 +114,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": True,
             "use_lazy_init": True,
             "precision": "fp32",
@@ -123,6 +124,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp32",
             "initial_scale": 1,
@@ -138,6 +140,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 4,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "use_lazy_init": False,
             "precision": "fp32",
         },
@@ -163,6 +166,7 @@ def run_whisper_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 4,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
@@ -172,6 +176,7 @@ def run_whisper_test(test_config):
             "tp_size": 2,
             "pp_size": 2,
             "num_microbatches": 2,
+            "enable_metadata_cache": False,
             "enable_all_optimization": False,
             "use_lazy_init": False,
             "precision": "fp32",
...
@@ -16,7 +16,7 @@ from tests.kit.model_zoo import model_zoo
 @parameterize("lazy_init", [True, False])
 def check_shardformer_with_ddp(lazy_init: bool):
-    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt")
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt", exclude="transformers_gptj")
 
     # create shardformer
     # ranks: [0, 1, 2, 3]
...