"examples/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "bce9499ed33f7a8359bbb568c7ee18d72e8aa731"
Commit 21e0a42f authored by flybird11111, committed by Hongxin Liu
Browse files

[shardformer]fix, test gpt2 for AMP+TP (#4403)

* [shardformer] gpt2 tests fix

[shardformer] test all optimizations (#4399)

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] test all optimizations

[shardformer] gpt2 tests fix

* [shardformer] gpt2 tests fix
parent 7596e9ae
...@@ -210,7 +210,7 @@ def check_weight(org_model: Module, ...@@ -210,7 +210,7 @@ def check_weight(org_model: Module,
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight): if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [ sharded_weight_list = [
torch.zeros([*sharded_weight.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group)) torch.zeros_like(sharded_weight).to('cuda') for _ in range(dist.get_world_size(tp_group))
] ]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group) dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim) sharded_weight = torch.cat(sharded_weight_list, dim=dim)
...@@ -219,7 +219,7 @@ def check_weight(org_model: Module, ...@@ -219,7 +219,7 @@ def check_weight(org_model: Module,
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}") print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \ assert torch.allclose(org_weight.float(), sharded_weight.float(), atol=atol, rtol=rtol), \
f"shard model weight is not equal to origin model weight\n{org_weight}\n{sharded_weight}" f"shard model weight {suffix} is not equal to origin model weight\n{org_weight}\n{sharded_weight}"
def check_grad(org_model: Module, def check_grad(org_model: Module,
...@@ -236,9 +236,7 @@ def check_grad(org_model: Module, ...@@ -236,9 +236,7 @@ def check_grad(org_model: Module,
shard_weight = getattr_(sharded_model, suffix).weight shard_weight = getattr_(sharded_model, suffix).weight
if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight): if is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight):
shard_grad_list = [ shard_grad_list = [torch.zeros_like(shard_grad).to('cuda') for _ in range(dist.get_world_size(tp_group))]
torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(shard_grad_list, shard_grad, tp_group) dist.all_gather(shard_grad_list, shard_grad, tp_group)
shard_grad = torch.cat(shard_grad_list, dim=dim) shard_grad = torch.cat(shard_grad_list, dim=dim)
......
...@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -23,7 +23,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = \
build_model_from_hybrid_plugin(model_fn, loss_fn, test_config) build_model_from_hybrid_plugin(model_fn, loss_fn, test_config)
org_loss, org_output, sharded_loss, sharded_output = \ org_loss, org_output, sharded_loss, sharded_output = \
run_forward_backward_with_hybrid_plugin( run_forward_backward_with_hybrid_plugin(
org_model, org_model,
...@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -47,7 +46,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
if org_model.__class__.__name__ == 'GPT2Model': if org_model.__class__.__name__ == 'GPT2Model':
check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol) check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
# check loss
check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol) check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
def unwrap(module): def unwrap(module):
...@@ -92,13 +90,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -92,13 +90,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
'num_microbatches': 4, 'num_microbatches': 4,
'enable_all_optimization': True, 'enable_all_optimization': True,
'use_lazy_init': True, 'use_lazy_init': True,
'precision': 'fp32', 'precision': 'fp16',
'initial_scale': 1,
}, { }, {
'tp_size': 1, 'tp_size': 1,
'pp_size': 2, 'pp_size': 2,
'num_microbatches': 4, 'num_microbatches': 4,
'enable_all_optimization': True, 'enable_all_optimization': True,
'use_lazy_init': False, 'use_lazy_init': True,
'precision': 'fp16', 'precision': 'fp16',
'initial_scale': 1, 'initial_scale': 1,
}, { }, {
...@@ -112,7 +111,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, ...@@ -112,7 +111,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
def run_gpt2_test(test_config): def run_gpt2_test(test_config):
# TODO: add test_config for TP+DP after supporting & debugging it # TODO: add test_config for TP+DP after supporting & debugging it
# TODO: check and debug TP+AMP
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt') sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment