Unverified Commit efba0f44 authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

Merge pull request #4612 from hpcaitech/feature/shardformer

[shardformer] update hybrid parallel plugin and fix bugs
parents ac178ca5 fae6c92e
from colossalai.shardformer.policies.whisper import WhisperPolicy
def test_whisper_pipeline_distribution():
num_test_cases = 8
test_dict = {
'num_encoder_layers': [2, 1, 3, 2, 3, 2, 10, 5],
'num_decoder_layers': [2, 8, 0, 2, 1, 5, 6, 22],
'num_stages': [2, 2, 2, 4, 4, 4, 8, 8],
'decoder_starting_stage': [1, 1, 2, 2, 3, 1, 5, 2]
}
for i in range(num_test_cases):
_, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(test_dict['num_encoder_layers'][i],
test_dict['num_decoder_layers'][i],
test_dict['num_stages'][i])
assert test_dict['decoder_starting_stage'][i] == decoder_starting_stage
def test_whisper_pipeline_layers():
num_test_cases = 4
test_dict = {
'num_encoder_layers': [2, 3, 2, 4],
'num_decoder_layers': [2, 0, 2, 8],
'num_stages': [2, 2, 4, 4],
'layers_per_stage': [[[0, 2], [0, 2]], [[0, 1], [1, 3]], [[0, 1], [1, 2], [0, 1], [1, 2]],
[[0, 4], [0, 3], [3, 6], [6, 8]]]
}
for i in range(num_test_cases):
layers_per_stage, decoder_starting_stage = WhisperPolicy.distribute_whisper_layers(
test_dict['num_encoder_layers'][i], test_dict['num_decoder_layers'][i], test_dict['num_stages'][i])
for stage in range(test_dict['num_stages'][i]):
start_idx, end_idx = test_dict['layers_per_stage'][i][stage]
predicted_start, predicted_end = WhisperPolicy.get_whisper_stage_index(layers_per_stage, stage,
decoder_starting_stage)
assert start_idx == predicted_start
assert end_idx == predicted_end
if __name__ == '__main__':
test_whisper_pipeline_distribution()
test_whisper_pipeline_layers()
import copy
from functools import partial
from types import MethodType
import pytest
import torch
import torch.nn as nn
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import OptimizerWrapper
from colossalai.pipeline.schedule.interleaved_pp import InterleavedSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
class MlpModel(nn.Module):
def __init__(self):
super(MlpModel, self).__init__()
self.linear1 = nn.Linear(4, 8)
self.linear2 = nn.Linear(8, 8)
self.linear3 = nn.Linear(8, 8)
self.linear4 = nn.Linear(8, 8)
self.linear5 = nn.Linear(8, 8)
self.linear6 = nn.Linear(8, 8)
self.linear7 = nn.Linear(8, 8)
self.linear8 = nn.Linear(8, 4)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = self.linear4(x)
x = self.linear5(x)
x = self.linear6(x)
x = self.linear7(x)
x = self.linear8(x)
return x
def pp_linear_fwd(forward,
data: torch.Tensor = None,
input_obj: torch.Tensor = None,
stage_mgr: PipelineStageManager = None,
num_chunks: int = None,
model_chunk_id: int = None):
if stage_mgr.is_first_stage() and model_chunk_id == 0:
return {'input_obj': forward(data)}
elif stage_mgr.is_last_stage() and model_chunk_id == num_chunks - 1:
return forward(input_obj)
else:
return {'input_obj': forward(input_obj)}
@parameterize("num_micro_batches", [4, 8, 12])
def examine_pp(num_micro_batches):
"""
This test is to examine the correctness of interleaved 1F1B, compared with torch.
Be aware it contains some hardcodes.
"""
world_size = torch.distributed.get_world_size()
local_rank = torch.distributed.get_rank()
seed_all(1453)
NUM_MICRO_BATCHS = num_micro_batches
BATCH_SIZE = num_micro_batches
NUM_CHUNKS = 2
# create model
torch_model = MlpModel().cuda()
pp_model = copy.deepcopy(torch_model).cuda()
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM, is_virtual=True)
schedule = InterleavedSchedule(NUM_MICRO_BATCHS, NUM_CHUNKS, stage_manager)
sharded_model = torch.nn.ModuleList()
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
sub_model._forward = sub_model.forward
sub_model.forward = MethodType(
partial(pp_linear_fwd,
stage_mgr=stage_manager,
num_chunks=NUM_CHUNKS,
model_chunk_id=len(sharded_model)), sub_model._forward)
sharded_model.append(sub_model.cuda())
# create optimizer
torch_optimizer = torch.optim.SGD(torch_model.parameters(), lr=1)
pp_optimizer = OptimizerWrapper(torch.optim.SGD(sharded_model.parameters(), lr=1))
# create
seed_all(1453)
if local_rank == 0:
input_list = [torch.rand(BATCH_SIZE, 4).cuda()]
else:
input_list = [torch.zeros(BATCH_SIZE, 4).cuda()]
torch.distributed.all_reduce(input_list[0])
criterion = lambda x, y: torch.mean(x)
# forward and backward
torch_output = torch_model(input_list[0])
torch_loss = criterion(torch_output, _)
torch_loss.backward()
pp_ret = schedule.forward_backward_step(sharded_model,
pp_optimizer,
iter(input_list),
criterion,
return_loss=True,
return_outputs=True)
# check loss
if stage_manager.is_last_stage():
assert torch.allclose(torch_loss, pp_ret['loss'])
# check gradients
torch_grad = []
for torch_p in torch_model.parameters():
torch_grad.append(torch_p.grad.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_grad[idx + local_rank * 2], pp_p.grad.data)
else:
assert torch.allclose(torch_grad[idx + local_rank * 2 + 6], pp_p.grad.data)
# step
torch_optimizer.step()
pp_optimizer.step()
# check updated param
torch_param = []
for torch_p in torch_model.parameters():
torch_param.append(torch_p.data)
for idx, pp_p in enumerate(sharded_model.parameters()):
if idx < 2:
assert torch.allclose(torch_param[idx + local_rank * 2], pp_p.data)
else:
assert torch.allclose(torch_param[idx + local_rank * 2 + 6], pp_p.data)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
examine_pp()
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp():
spawn(run_dist, 4)
if __name__ == '__main__':
test_pp()
......@@ -61,7 +61,7 @@ def examine_pp():
DP_DIM, PP_DIM, TP_DIM = 0, 1, 2
pg_mesh = ProcessGroupMesh(1, world_size, 1)
stage_manager = PipelineStageManager(pg_mesh, PP_DIM)
schedule = OneForwardOneBackwardSchedule(NUM_MICRO_BATCHS, stage_manager)
schedule = OneForwardOneBackwardSchedule(stage_manager, num_microbatches=NUM_MICRO_BATCHS)
for idx, (_, sub_model) in enumerate(pp_model.named_children()):
if idx % (world_size) == local_rank:
......
......@@ -49,15 +49,6 @@ def check_stage_manager():
next_rank = ranks_in_group[ranks_in_group.index(rank) + 1]
assert stage_manager.get_next_rank() == next_rank
# check virtual stage
stage_manager.set_num_virtual_stages(PP_SIZE * 2)
assert stage_manager.num_virtual_stages == PP_SIZE * 2
stage_manager.set_virtual_stage(stage_manager.stage * 2)
assert stage_manager.virtual_stage == stage_manager.stage * 2
with stage_manager.switch_virtual_stage(stage_manager.stage * 2 + 1):
assert stage_manager.virtual_stage == stage_manager.stage * 2 + 1
assert stage_manager.virtual_stage == stage_manager.stage * 2
# check p2p groups
for prev, cur in zip(ranks_in_group[:-1], ranks_in_group[1:]):
if rank in [prev, cur]:
......
......@@ -53,8 +53,7 @@ def rearrange(tensor: torch.Tensor, dim: int):
return rearanged_tensor
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_col(lazy_init: bool):
def check_linear_conv_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
......@@ -62,7 +61,9 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear_conv_col = GPT2FusedLinearConv1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
n_fused=3)
seq_parallel=seq_parallel,
n_fused=3,
overlap=overlap)
assert linear.weight.shape == torch.Size([48, 192])
assert linear.bias.shape == torch.Size([192])
......@@ -76,10 +77,11 @@ def check_linear_conv_1d_col(lazy_init: bool):
linear.load_state_dict(linear_conv_col.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_conv_col(x)
assert_close(rearrange(out, 1), gather_out)
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
gather_out = linear_conv_col(x_for_shard)
assert_close(rearrange(out, -1), gather_out)
# check backward correctness
out.sum().backward()
......@@ -89,14 +91,16 @@ def check_linear_conv_1d_col(lazy_init: bool):
assert_close(target_grad, linear_conv_col.weight.grad)
@parameterize('lazy_init', [False, True])
def check_linear_conv_1d_row(lazy_init: bool):
def check_linear_conv_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = Conv1D(192, 48).cuda()
with ctx:
linear_copy = Conv1D(192, 48).cuda()
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = GPT2FusedLinearConv1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear.weight.shape == torch.Size([48, 192])
assert linear_row.weight.shape == torch.Size([24, 192])
......@@ -109,10 +113,11 @@ def check_linear_conv_1d_row(lazy_init: bool):
linear.load_state_dict(linear_row.state_dict())
# check computation correctness
x = torch.rand(4, 48).cuda()
x = torch.rand(1, 4, 48).cuda()
out = linear(x)
gather_out = linear_row(x)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
......@@ -123,12 +128,19 @@ def check_linear_conv_1d_row(lazy_init: bool):
assert_close(target_grad, linear_row.weight.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [True])
def check_gpt2_qkv_fused_linear_1d(lazy_init: bool, seq_parallel: bool, overlap: bool):
check_linear_conv_1d_col(lazy_init, seq_parallel, overlap)
check_linear_conv_1d_row(lazy_init, seq_parallel)
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# test for linear conv
check_linear_conv_1d_col()
check_linear_conv_1d_row()
check_gpt2_qkv_fused_linear_1d()
@rerun_if_address_is_in_use()
......
......@@ -12,13 +12,16 @@ from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
@parameterize('lazy_init', [False, True])
def check_linear_1d_col(lazy_init: bool):
def check_linear_1d_col(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_col = Linear1D_Col.from_native_module(linear_copy, process_group=None, gather_output=True)
linear_col = Linear1D_Col.from_native_module(linear_copy,
process_group=None,
gather_output=True,
seq_parallel=seq_parallel,
overlap=overlap)
# ensure that the parameters are distributed
assert is_distributed_tensor(linear_col.weight)
......@@ -35,10 +38,11 @@ def check_linear_1d_col(lazy_init: bool):
linear_col.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
out = linear(x_for_unshard)
......@@ -56,17 +60,21 @@ def check_linear_1d_col(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_gard, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_1d_row(lazy_init: bool):
def check_linear_1d_row(lazy_init: bool, seq_parallel: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear = nn.Linear(32, 128).cuda()
with ctx:
linear_copy = nn.Linear(32, 128).cuda()
linear_row = Linear1D_Row.from_native_module(linear_copy, process_group=None, parallel_input=False)
linear_row = Linear1D_Row.from_native_module(linear_copy,
process_group=None,
parallel_input=False,
seq_parallel=seq_parallel)
assert linear_row.weight.shape == torch.Size([128, 16])
assert linear_row.bias.shape == torch.Size([128])
......@@ -77,7 +85,8 @@ def check_linear_1d_row(lazy_init: bool):
linear_row.load_state_dict(linear.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
......@@ -86,7 +95,8 @@ def check_linear_1d_row(lazy_init: bool):
# run forward
out = linear(x_for_unshard)
gather_out = linear_row(x_for_shard)
assert_close(out, gather_out)
target_out = out if seq_parallel is False else torch.chunk(out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, gather_out)
# check backward correctness
out.sum().backward()
......@@ -102,8 +112,7 @@ def check_linear_1d_row(lazy_init: bool):
assert_close(x_for_unshard.grad, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
def check_linear_col_plus_row(lazy_init: bool):
def check_linear_col_plus_row(lazy_init: bool, seq_parallel: bool, overlap: bool):
ctx = LazyInitContext() if lazy_init else nullcontext()
linear_1 = nn.Linear(32, 128).cuda()
......@@ -112,8 +121,15 @@ def check_linear_col_plus_row(lazy_init: bool):
with ctx:
linear_1_copy = nn.Linear(32, 128).cuda()
linear_2_copy = nn.Linear(128, 32).cuda()
linear_col = Linear1D_Col.from_native_module(linear_1_copy, process_group=None, gather_output=False)
linear_row = Linear1D_Row.from_native_module(linear_2_copy, process_group=None, parallel_input=True)
linear_col = Linear1D_Col.from_native_module(linear_1_copy,
process_group=None,
gather_output=False,
seq_parallel=seq_parallel,
overlap=overlap)
linear_row = Linear1D_Row.from_native_module(linear_2_copy,
process_group=None,
parallel_input=True,
seq_parallel=seq_parallel)
linear_1.load_state_dict(linear_col.state_dict())
linear_col.load_state_dict(linear_1.state_dict())
......@@ -121,16 +137,18 @@ def check_linear_col_plus_row(lazy_init: bool):
linear_row.load_state_dict(linear_2.state_dict())
# check computation correctness
x = torch.rand(4, 32).cuda()
# [batch_size, seq_len, hidden_size]
x = torch.rand(2, 4, 32).cuda()
x_for_unshard = x.expand_as(x.clone())
x_for_unshard.requires_grad_(True)
x_for_shard = x.expand_as(x.clone())
x_for_shard = x.expand_as(x.clone()) if seq_parallel is False else torch.chunk(x.clone(), 2, dim=1)[dist.get_rank()]
x_for_shard.requires_grad_(True)
# run forward
unshard_out = linear_2(linear_1(x_for_unshard))
shard_out = linear_row(linear_col(x_for_shard))
assert_close(unshard_out, shard_out)
target_out = unshard_out if seq_parallel is False else torch.chunk(unshard_out.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_out, shard_out)
# check backward correctness
unshard_out.sum().backward()
......@@ -143,19 +161,28 @@ def check_linear_col_plus_row(lazy_init: bool):
# check the input gradients
assert x_for_shard.grad is not None
assert x_for_unshard.grad is not None
assert_close(x_for_unshard.grad, x_for_shard.grad)
target_unshard_gard = x_for_unshard.grad if seq_parallel is False else torch.chunk(
x_for_unshard.grad.clone(), 2, dim=1)[dist.get_rank()]
assert_close(target_unshard_gard, x_for_shard.grad)
@parameterize('lazy_init', [False, True])
@parameterize('seq_parallel', [False, True])
@parameterize('overlap', [True])
def run_dist_linear_test(lazy_init, seq_parallel, overlap):
check_linear_1d_col(lazy_init, seq_parallel, overlap)
check_linear_1d_row(lazy_init, seq_parallel)
check_linear_col_plus_row(lazy_init, seq_parallel, overlap)
def run_dist(rank, world_size, port):
def check_dist_linear(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
check_linear_1d_col()
check_linear_1d_row()
check_linear_col_plus_row()
run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
spawn(run_dist, nprocs=2)
spawn(check_dist_linear, nprocs=2)
if __name__ == '__main__':
......
import copy
import math
from contextlib import nullcontext
from typing import Any, Callable, Dict, List, Optional
......@@ -12,6 +13,7 @@ from torch.optim import Adam, Optimizer
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
......@@ -25,6 +27,7 @@ def build_model(model_fn,
enable_tensor_parallelism=True,
enable_flash_attention=False,
enable_jit_fused=False,
enable_sequence_parallelism=False,
use_lazy_init: bool = False):
# create new model
ctx = LazyInitContext() if use_lazy_init else nullcontext()
......@@ -38,7 +41,8 @@ def build_model(model_fn,
shard_config = ShardConfig(enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
enable_jit_fused=enable_jit_fused,
enable_sequence_parallelism=enable_sequence_parallelism)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model, shared_params = shard_former.optimize(model_copy)
......@@ -135,6 +139,16 @@ def run_forward_backward_with_hybrid_plugin(org_model: Module, sharded_model: Mo
return loss
data = data_gen_fn()
if booster.plugin.enable_sequence_parallelism and booster.plugin.tp_size != 0:
seq_len = data['input_ids'].shape[1]
lcm = booster.plugin.tp_size * seq_len // math.gcd(booster.plugin.tp_size, seq_len)
times = lcm // seq_len
input_shape = data['input_ids'].shape
for k, v in data.items():
if v.shape == input_shape:
data[k] = v.repeat(1, times)
sharded_model.train()
if booster.plugin.stage_manager is not None:
for k, v in data.items():
......@@ -177,11 +191,10 @@ def check_output_hidden_state(org_output: Tensor,
org_hidden_state = org_output.last_hidden_state
if stage_manager is None:
sharded_hidden_state = sharded_output.last_hidden_state
if stage_manager and stage_manager.is_last_stage():
sharded_hidden_state = torch.cat([output.last_hidden_state for output in sharded_output['outputs']], dim=dim)
sharded_hidden_state = sharded_output['outputs']['last_hidden_state']
else:
sharded_hidden_state = sharded_output.last_hidden_state
assert torch.allclose(org_hidden_state.float(), sharded_hidden_state.float(), atol=atol, rtol=rtol), \
f"shard model's output hidden state is not equal to origin model's last hidden state\n{org_hidden_state}\n{sharded_hidden_state}"
......@@ -219,6 +232,43 @@ def check_weight(org_model: Module,
f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def get_grad_tensors_for_check(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
tp_group: ProcessGroup = None,
dim: int = 0,
atol: float = 1e-5,
rtol: float = 1e-3,
verbose: bool = False,
name: str = None):
grad_to_check = {}
for suffix in layer_suffix:
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim)
# embedding may be resized when using tensor parallel
if shard_grad.shape[0] > org_grad.shape[0]:
shard_grad = shard_grad[:org_grad.shape[0], :]
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' grad: {org_grad}, {shard_grad}")
grad_to_check[suffix] = {
"org_grad": org_grad.float(),
"shard_grad": shard_grad.float(),
"rtol": rtol,
"atol": atol
}
return grad_to_check
# used by sam/blip2
def check_grad(org_model: Module,
sharded_model: Module,
layer_suffix: List[str],
......@@ -231,7 +281,6 @@ def check_grad(org_model: Module,
org_grad = getattr_(org_model, suffix).weight.grad
shard_grad = getattr_(sharded_model, suffix).weight.grad
shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
dist.all_gather(shard_grad_list, shard_grad, tp_group)
......@@ -246,3 +295,30 @@ def check_grad(org_model: Module,
assert torch.allclose(
org_grad.float(), shard_grad.float(), rtol=rtol, atol=atol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
def unwrap_model(module: Module,
base_model_class_name: Optional[str] = None,
base_model_attribute_name: Optional[str] = None):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if base_model_class_name is None:
return module
if module.__class__.__name__ == base_model_class_name:
return module
return getattr(module, base_model_attribute_name, None)
def check_all_grad_tensors(check_tensors):
"""
"org_grad": tensor to be compared from the original model
"shard_grad": tensor to be compared from the sharded model
"""
for suffix, check_info in check_tensors.items():
org_grad = check_info["org_grad"]
shard_grad = check_info["shard_grad"]
rtol = check_info["rtol"]
atol = check_info["atol"]
assert torch.allclose(
org_grad, shard_grad, atol=atol, rtol=rtol
), f"error attribute '{suffix}', orgin model grad is not equal to shard model grad\n{org_grad}\n{shard_grad}"
......@@ -10,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -32,42 +34,58 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BertModel':
bert = org_model
sharded_bert = sharded_model.unwrap()
else:
bert = org_model.bert
sharded_bert = sharded_model.unwrap().bert
bert = unwrap_model(org_model, 'BertModel', 'bert')
sharded_bert = unwrap_model(sharded_model, 'BertModel', 'bert')
col_layer_for_check = ['encoder.layer[0].output.dense']
row_layer_for_check = ['embeddings.word_embeddings', 'encoder.layer[0].intermediate.dense']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
#check_weight(bert.embeddings.word_embeddings, sharded_bert.embeddings.word_embeddings, tp_group, atol=1e-5, rtol=1e-3)
#check_weight(bert.encoder.layer[0].attention.self.query, sharded_bert.encoder.layer[0].attention.self.query, tp_group, atol=5e-3, rtol=1e-3)
check_grad(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(bert, sharded_bert, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
col_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(bert,
sharded_bert,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'BertModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
else:
......@@ -75,6 +93,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if stage_manager is None or stage_manager.is_first_stage():
check_weight(bert, sharded_bert, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
......@@ -98,6 +119,29 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bert_test(test_config):
......@@ -111,12 +155,50 @@ def run_bert_test(test_config):
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bert_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
def check_bert(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_test()
def check_bert_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bert_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -124,5 +206,13 @@ def test_bert():
spawn(check_bert, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bert_3d():
spawn(check_bert_3d, 8)
if __name__ == "__main__":
test_bert()
test_bert_3d()
......@@ -3,16 +3,19 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -34,6 +37,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
bloom = unwrap_model(org_model, 'BloomModel', 'transformer')
sharded_bloom = unwrap_model(sharded_model, 'BloomModel', 'transformer')
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
row_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
col_layer_grads = get_grad_tensors_for_check(bloom,
sharded_bloom,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
......@@ -45,28 +85,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
bloom = org_model
sharded_bloom = sharded_model.unwrap()
else:
bloom = org_model.transformer
sharded_bloom = sharded_model.unwrap().transformer
# check grad
row_layer_for_check = ['h[0].self_attention.query_key_value', 'word_embeddings']
col_layer_for_check = ['h[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-5
else:
atol, rtol = 5e-3, 5e-3
check_grad(bloom, sharded_bloom, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
check_grad(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
......@@ -74,6 +92,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(bloom, sharded_bloom, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
......@@ -97,18 +118,72 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_bloom_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_bloom_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
......@@ -118,6 +193,12 @@ def check_bloom(rank, world_size, port):
run_bloom_test()
def check_bloom_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_bloom_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -125,5 +206,13 @@ def test_bloom():
spawn(check_bloom, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_bloom_3d():
spawn(check_bloom_3d, 8)
if __name__ == "__main__":
test_bloom()
test_bloom_3d()
......@@ -4,16 +4,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -35,35 +38,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ChatGLMModel':
chatglm_model = org_model
shard_chatglm_model = sharded_model.unwrap()
else:
chatglm_model = org_model.transformer
shard_chatglm_model = sharded_model.unwrap().transformer
chatglm_model = unwrap_model(org_model, 'ChatGLMModel', 'transformer')
shard_chatglm_model = unwrap_model(sharded_model, 'ChatGLMModel', 'transformer')
# check grad
row_layer_for_check = ['encoder.layers[0].self_attention.query_key_value', 'embedding.word_embeddings']
col_layer_for_check = ['encoder.layers[0].self_attention.dense']
if stage_manager is None or stage_manager.is_first_stage():
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(chatglm_model,
row_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
row_layer_for_check,
tp_group,
......@@ -72,7 +61,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=0,
verbose=False)
check_grad(chatglm_model,
col_layer_grads = get_grad_tensors_for_check(chatglm_model,
shard_chatglm_model,
col_layer_for_check,
tp_group,
......@@ -80,10 +69,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ChatGLMModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol, dim=1)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
......@@ -98,6 +103,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
......@@ -121,12 +130,55 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}])
def run_chatglm_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_chatglm_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_chatglm')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
......@@ -142,6 +194,12 @@ def check_chatglm(rank, world_size, port):
run_chatglm_test()
def check_chatglm_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_chatglm_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -149,5 +207,13 @@ def test_chatglm():
spawn(check_chatglm, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_chatglm_3d():
spawn(check_chatglm_3d, 8)
if __name__ == "__main__":
test_chatglm()
test_chatglm_3d()
......@@ -3,18 +3,20 @@ import torch
from torch import distributed as dist
import colossalai
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -36,6 +38,43 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
gpt2 = unwrap_model(org_model, 'GPT2Model', 'transformer')
sharded_gpt2 = unwrap_model(sharded_model, 'GPT2Model', 'transformer')
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
col_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
row_layer_grads = get_grad_tensors_for_check(gpt2,
sharded_gpt2,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
......@@ -48,32 +87,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module):
if isinstance(module, HybridParallelModule):
module = module.unwrap()
if module.__class__.__name__ == 'GPT2Model':
return module
return module.transformer
# unwrap model
gpt2 = unwrap(org_model)
sharded_gpt2 = unwrap(sharded_model)
col_layer_for_check = ['h[0].mlp.c_fc']
row_layer_for_check = ['wte', 'h[0].mlp.c_proj']
# check grad
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
check_grad(gpt2, sharded_gpt2, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
......@@ -81,6 +95,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
atol, rtol = 5e-3, 5e-3
check_weight(gpt2, sharded_gpt2, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
......@@ -106,12 +124,80 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'enable_sequence_parallelism': True,
'precision': 'fp32',
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_gpt2_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
@clear_cache_before_run()
def run_gpt2_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
......@@ -127,10 +213,13 @@ def check_gpt2(rank, world_size, port):
run_gpt2_test()
# TODO(ver217): fix this
def check_gpt2_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_gpt2_3d_test()
@pytest.mark.skip("this will stuck in CI")
@pytest.mark.skip(reason="This test will hang in CI")
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -138,5 +227,13 @@ def test_gpt2():
spawn(check_gpt2, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2_3d():
spawn(check_gpt2_3d, 8)
if __name__ == "__main__":
test_gpt2()
test_gpt2_3d()
......@@ -6,16 +6,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
......@@ -39,35 +42,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'LlamaModel':
llama_model = org_model
shard_llama_model = sharded_model.unwrap()
else:
llama_model = org_model.model
shard_llama_model = sharded_model.unwrap().model
llama_model = unwrap_model(org_model, 'LlamaModel', 'model')
shard_llama_model = unwrap_model(sharded_model, 'LlamaModel', 'model')
# check grad
row_layer_for_check = ['layers[0].self_attn.q_proj', 'embed_tokens']
col_layer_for_check = ['layers[0].self_attn.o_proj']
if stage_manager is None or stage_manager.is_first_stage():
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-4
else:
atol, rtol = 5e-3, 5e-3
check_grad(llama_model,
row_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
row_layer_for_check,
tp_group,
......@@ -75,7 +64,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(llama_model,
col_layer_grads = get_grad_tensors_for_check(llama_model,
shard_llama_model,
col_layer_for_check,
tp_group,
......@@ -83,10 +72,26 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'LlamaModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
......@@ -101,6 +106,9 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
......@@ -128,19 +136,74 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_llama_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_llama_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
......@@ -150,6 +213,12 @@ def check_llama(rank, world_size, port):
run_llama_test()
def check_llama_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_llama_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -157,5 +226,13 @@ def test_llama():
spawn(check_llama, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama_3d():
spawn(check_llama_3d, 8)
if __name__ == "__main__":
test_llama()
test_llama_3d()
......@@ -6,16 +6,19 @@ from torch import distributed as dist
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
......@@ -39,34 +42,21 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'OPTModel':
opt_model = org_model
shard_opt_model = sharded_model.unwrap()
else:
opt_model = org_model.model
shard_opt_model = sharded_model.unwrap().model
opt_model = unwrap_model(org_model, 'OPTModel', 'model')
shard_opt_model = unwrap_model(sharded_model, 'OPTModel', 'model')
# check grad
row_layer_for_check = ['decoder.layers[0].self_attn.q_proj', 'decoder.embed_tokens'] # 'decoder.embed_tokens'
col_layer_for_check = ['decoder.layers[0].self_attn.out_proj']
if stage_manager is None or stage_manager.is_first_stage():
# Save gradient tensors for comparison between the original model and the sharded model.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-6, 1e-3
else:
atol, rtol = 3e-2, 3e-2
check_grad(opt_model,
atol, rtol = 4e-2, 4e-2
row_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
row_layer_for_check,
tp_group,
......@@ -74,7 +64,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(opt_model,
col_layer_grads = get_grad_tensors_for_check(opt_model,
shard_opt_model,
col_layer_for_check,
tp_group,
......@@ -82,10 +72,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'OPTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
......@@ -100,6 +105,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
Randomizer.reset_index()
torch.cuda.empty_cache()
......@@ -123,12 +132,62 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_opt_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_opt_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
......@@ -144,6 +203,12 @@ def check_OPTModel(rank, world_size, port):
run_opt_test()
def check_opt_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_opt_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -151,5 +216,13 @@ def test_OPTModel():
spawn(check_OPTModel, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_opt_3d():
spawn(check_opt_3d, 8)
if __name__ == '__main__':
test_OPTModel()
test_opt_3d()
import pytest
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.logging import disable_existing_loggers
......@@ -9,11 +10,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -35,6 +38,32 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwrap model
t5 = unwrap_model(org_model)
sharded_t5 = unwrap_model(sharded_model)
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
row_layer_grads = get_grad_tensors_for_check(t5,
sharded_t5,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
......@@ -47,30 +76,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
t5 = org_model
sharded_t5 = sharded_model.unwrap()
row_layer_for_check = ['shared', 'encoder.block[0].layer[0].SelfAttention.q']
# check weights and gradients
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_grad(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0)
# check weights after optimizer.step()
org_optimizer.step()
sharded_optimizer.step()
if test_config['precision'] == 'fp32':
atol, rtol = 1e-4, 1e-3
atol, rtol = 5e-4, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(t5, sharded_t5, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
......@@ -99,17 +115,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
@clear_cache_before_run()
def run_t5_test(test_config):
# TODO(baizhou): add plugin_config for TP+DP after supporting & debugging it
# {'tp_size': 2, 'pp_size': 1, 'enable_fused_normalization': True}
# TODO(baizhou): add test_config for flash attention & jit operator after supporting
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
......@@ -125,12 +160,49 @@ def run_t5_test(test_config):
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp16',
'zero_stage': 1,
'initial_scale': 1,
},
])
def run_t5_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_t5(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_test()
def check_t5_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_t5_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -138,5 +210,13 @@ def test_t5():
spawn(check_t5, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_t5_3d():
spawn(check_t5_3d, 8)
if __name__ == "__main__":
test_t5()
test_t5_3d()
......@@ -9,11 +9,13 @@ from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_ad
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_grad,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
unwrap_model,
)
......@@ -35,35 +37,22 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# unwrap model
if org_model.__class__.__name__ == 'ViTModel':
vit_model = org_model
shard_vit_model = sharded_model.unwrap()
else:
vit_model = org_model.vit
shard_vit_model = sharded_model.unwrap().vit
vit_model = unwrap_model(org_model, 'ViTModel', 'vit')
shard_vit_model = unwrap_model(sharded_model, 'ViTModel', 'vit')
# check grad
row_layer_for_check = ['encoder.layer[0].attention.attention.query', 'embeddings.patch_embeddings.projection']
col_layer_for_check = ['encoder.layer[0].attention.output.dense']
if stage_manager is None or stage_manager.is_first_stage():
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if (stage_manager is None or stage_manager.is_first_stage()) and booster.plugin.zero_stage == 0:
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
check_grad(vit_model,
row_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
row_layer_for_check,
tp_group,
......@@ -71,7 +60,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=0,
verbose=False)
check_grad(vit_model,
col_layer_grads = get_grad_tensors_for_check(vit_model,
shard_vit_model,
col_layer_for_check,
tp_group,
......@@ -79,10 +68,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
rtol=rtol,
dim=1,
verbose=False)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# check weights after optimizer.step()
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 1e-5, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'ViTModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if stage_manager is None or stage_manager.is_first_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 5e-3, 1e-3
......@@ -97,9 +101,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
dim=1,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
#TODO: num_microbatch size = 2 inf loss
@parameterize('test_config', [{
'tp_size': 2,
'pp_size': 2,
......@@ -120,15 +128,36 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32'
}, {
'tp_size': 2,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 2,
'precision': 'fp16',
'initial_scale': 1
}, {
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': True,
'use_lazy_init': False,
'zero_stage': 1,
'precision': 'fp16',
'initial_scale': 1
}])
def run_vit_test(test_config):
# TODO(baizhou): add test_config for TP+DP after supporting & debugging it
# TODO(baizhou): fix bug when settign lazy_init for Conv2D Layers in ViT models
# TODO: fix bug when settign lazy_init for Conv2D Layers in ViT models
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
......@@ -137,12 +166,48 @@ def run_vit_test(test_config):
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_vit_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
def check_vit(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_test()
def check_vit_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_vit_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
......@@ -150,5 +215,13 @@ def test_vit():
spawn(check_vit, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit_3d():
spawn(check_vit_3d, 8)
if __name__ == "__main__":
test_vit()
test_vit_3d()
......@@ -3,6 +3,8 @@ import torch
import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.layer.utils import Randomizer
from colossalai.tensor.d_tensor.api import clear_layout_converter
from colossalai.testing import (
assert_hf_output_close,
clear_cache_before_run,
......@@ -11,55 +13,205 @@ from colossalai.testing import (
spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, check_grad, run_forward
from tests.test_shardformer.test_model._utils import (
build_model_from_hybrid_plugin,
check_all_grad_tensors,
check_loss,
check_output_hidden_state,
check_weight,
get_grad_tensors_for_check,
run_forward_backward_with_hybrid_plugin,
)
def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
# check forward
org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
output_transform_fn, loss_fn)
assert_hf_output_close(org_output, shard_output, ignore_keys='past_key_values', atol=1e-5)
# do backward
org_loss.backward()
shard_loss.backward()
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin(
org_model,
sharded_model,
sharded_optimizer,
data_gen_fn,
output_transform_fn,
criterion,
booster)
stage_manager = booster.plugin.stage_manager
tp_group = booster.plugin.tp_group
# unwarp the model
if org_model.__class__.__name__ == 'WhisperForConditionalGeneration':
whisper = org_model.model
sharded_whisper = sharded_model.model
sharded_whisper = sharded_model.unwrap().model
else:
whisper = org_model
sharded_whisper = sharded_model
sharded_whisper = sharded_model.unwrap()
# check grad
if org_model.__class__.__name__ == 'WhisperForAudioClassification':
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj']
else:
col_layer_for_check = ['encoder.layers[0].self_attn.q_proj', 'decoder.layers[0].self_attn.q_proj']
row_layer_for_check = ['encoder.layers[0].self_attn.out_proj', 'decoder.layers[0].self_attn.out_proj']
check_grad(whisper, sharded_whisper, col_layer_for_check, atol=1e-6, rtol=1e-5, dim=0, verbose=False)
check_grad(whisper, sharded_whisper, row_layer_for_check, atol=1e-6, rtol=1e-5, dim=1, verbose=False)
col_layer_for_check = [
'encoder.layers[0].self_attn.q_proj',
# 'decoder.layers[0].self_attn.q_proj'
]
row_layer_for_check = [
'encoder.layers[0].self_attn.out_proj',
#'decoder.layers[0].self_attn.out_proj'
]
# Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
grads_to_check = {}
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
row_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1)
col_layer_grads = get_grad_tensors_for_check(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0)
grads_to_check.update(col_layer_grads)
grads_to_check.update(row_layer_grads)
# optimizer executes step
org_optimizer.step()
sharded_optimizer.step()
# check last hidden state & loss
if stage_manager is None or stage_manager.is_last_stage():
if test_config['precision'] == 'fp32':
atol, rtol = 2e-4, 2e-4
else:
atol, rtol = 5e-3, 5e-3
if org_model.__class__.__name__ == 'WhisperModel':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
# check weights
if test_config['precision'] == 'fp32':
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():
check_weight(whisper,
sharded_whisper,
row_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=1,
verbose=False)
check_weight(whisper,
sharded_whisper,
col_layer_for_check,
tp_group,
atol=atol,
rtol=rtol,
dim=0,
verbose=False)
# check grads
check_all_grad_tensors(grads_to_check)
torch.cuda.empty_cache()
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('enable_flash_attention', [True, False])
@parameterize('enable_jit_fused', [True, False])
def run_whisper_test(enable_fused_normalization, enable_tensor_parallelism, enable_flash_attention, enable_jit_fused):
#TODO fix WhisperForConditionalGeneration enable jit fused operato
# TODO(jianghai) fix fp16
@parameterize(
'test_config',
[
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': True,
'use_lazy_init': True,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 1,
'pp_size': 2,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 4,
'pp_size': 1,
'enable_all_optimization': True,
'use_lazy_init': False,
'precision': 'fp32',
},
{
'tp_size': 1,
'pp_size': 4,
'num_microbatches': 4,
'use_lazy_init': False,
'precision': 'fp32',
},
# whisper is not supported fp16 for now.
])
def run_whisper_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(model_fn,
enable_fused_normalization=enable_fused_normalization,
enable_tensor_parallelism=enable_tensor_parallelism,
enable_flash_attention=enable_flash_attention,
enable_jit_fused=enable_jit_fused)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
if test_config['pp_size'] > 2 and name == 'transformers_whisper_for_audio_classification':
continue
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
Randomizer.reset_index()
torch.cuda.empty_cache()
@parameterize('test_config', [
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 4,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
{
'tp_size': 2,
'pp_size': 2,
'num_microbatches': 2,
'enable_all_optimization': False,
'use_lazy_init': False,
'precision': 'fp32',
'initial_scale': 1,
},
])
def run_whisper_3d_test(test_config):
sub_model_zoo = model_zoo.get_sub_registry('transformers_whisper')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
clear_layout_converter()
torch.cuda.empty_cache()
......@@ -69,12 +221,26 @@ def check_whisper(rank, world_size, port):
run_whisper_test()
def check_whisper_3d(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_whisper_3d_test()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper():
spawn(check_whisper, 2)
spawn(check_whisper, 4)
@pytest.mark.largedist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_whisper_3d():
spawn(check_whisper_3d, 8)
if __name__ == "__main__":
test_whisper()
test_whisper_3d()
......@@ -40,7 +40,6 @@ def forward_inplace(x, weight):
return out
@pytest.mark.gpu
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment