Commit f5f0ad26 authored by ver217, committed by Frank Lee

fix bert unit test

parent 56636169
 import torch
 import transformers
-from transformers import BertConfig, BertForSequenceClassification
 from packaging import version
 from torch.utils.data import SequentialSampler
+from transformers import BertConfig, BertForSequenceClassification
 from .registry import non_distributed_component_funcs
@@ -39,14 +39,14 @@ def get_training_components():
     num_layer = 2

     def bert_model_builder(checkpoint):
-        config = BertConfig(
-            gradient_checkpointing=checkpoint,
-            hidden_size=hidden_dim,
-            intermediate_size=hidden_dim * 4,
-            num_attention_heads=num_head,
-            max_position_embeddings=sequence_length,
-            num_hidden_layers=num_layer,
-        )
+        config = BertConfig(gradient_checkpointing=checkpoint,
+                            hidden_size=hidden_dim,
+                            intermediate_size=hidden_dim * 4,
+                            num_attention_heads=num_head,
+                            max_position_embeddings=sequence_length,
+                            num_hidden_layers=num_layer,
+                            hidden_dropout_prob=0.,
+                            attention_probs_dropout_prob=0.)
         print('building BertForSequenceClassification model')
         # adapting huggingface BertForSequenceClassification for single unitest calling interface
...
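Worth noting on the `BertConfig` change above: besides reflowing the arguments, it sets `hidden_dropout_prob` and `attention_probs_dropout_prob` to `0.`. Presumably this is part of the fix: the test compares gradients of the plain model against its ZeRO-sharded copy, and with dropout active each forward pass draws its own random mask, so the two sets of gradients cannot be expected to agree. A minimal sketch of the effect in plain PyTorch (illustrative only, not ColossalAI or test code):

```python
import torch
import torch.nn as nn

def make_model(p):
    torch.manual_seed(42)  # identical weight init for every copy
    return nn.Sequential(nn.Linear(8, 8), nn.Dropout(p), nn.Linear(8, 1))

x = torch.randn(4, 8)

for p in (0.5, 0.0):
    a, b = make_model(p), make_model(p)
    a(x).sum().backward()
    b(x).sum().backward()  # second pass draws a fresh dropout mask
    same = torch.allclose(a[0].weight.grad, b[0].weight.grad)
    print(f"dropout={p}: gradients match -> {same}")
# dropout=0.5: gradients match -> False  (independent random masks)
# dropout=0.0: gradients match -> True   (deterministic forward)
```

With dropout zeroed, `check_grads(..., loose=True)` only has to absorb fp16 rounding error rather than dropout noise.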
@@ -13,6 +13,7 @@ from colossalai.utils import free_port
 from colossalai.zero.shard_utils.tensor_shard_strategy import \
     TensorShardStrategy
 from colossalai.zero.sharded_model import ShardedModelV2
+from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
 from tests.components_to_test.registry import non_distributed_component_funcs
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -45,8 +46,7 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
 def run_dist(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    test_models = ['bert']
-    # repeated_computed_layers resnet18
+    test_models = ['repeated_computed_layers', 'resnet18', 'bert']
     shard_strategy = TensorShardStrategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
@@ -65,8 +65,7 @@ def run_dist(rank, world_size, port):
             run_fwd_bwd_no_criterion(model, data, label, False)
             run_fwd_bwd_no_criterion(zero_model, data, label, False)
         else:
-            # FIXME() data can be interger!
-            data, label = data.half().cuda(), label.cuda()
+            data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
             run_fwd_bwd(model, data, label, criterion, False)
             run_fwd_bwd(zero_model, data, label, criterion, False)
@@ -76,7 +75,6 @@ def run_dist(rank, world_size, port):
         check_grads(model, zero_model, loose=True)


-@pytest.mark.skip(reason="Under development")
 @pytest.mark.dist
 @pytest.mark.parametrize("world_size", [1, 2, 4])
 def test_shard_model_v2(world_size):
...
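Two things about this second file. First, the old `data.half().cuda()` is replaced by `cast_tensor_to_fp16(data).cuda()`, which resolves the removed FIXME: for BERT, `data` is a batch of integer token IDs, and an unconditional `.half()` would mangle them. I have not inspected `_zero3_utils`, but the helper's role here suggests it casts only floating-point tensors and passes integer ones through. A hedged sketch of that behavior (hypothetical re-implementation, not the library code):

```python
import torch

def cast_tensor_to_fp16_sketch(x):
    """Hypothetical stand-in for the real helper: cast float tensors to fp16,
    leave integer tensors (e.g. BERT input ids) untouched, and recurse into
    common containers."""
    if isinstance(x, torch.Tensor):
        return x.half() if torch.is_floating_point(x) else x
    if isinstance(x, (list, tuple)):
        return type(x)(cast_tensor_to_fp16_sketch(t) for t in x)
    if isinstance(x, dict):
        return {k: cast_tensor_to_fp16_sketch(v) for k, v in x.items()}
    return x

ids = torch.randint(0, 30522, (2, 16))         # int64 token ids
acts = torch.randn(2, 16, 32)                  # float32 activations
print(cast_tensor_to_fp16_sketch(ids).dtype)   # torch.int64 (unchanged)
print(cast_tensor_to_fp16_sketch(acts).dtype)  # torch.float16
```

Second, the commit re-enables `repeated_computed_layers` and `resnet18` in `test_models` and drops the `@pytest.mark.skip`, so `test_shard_model_v2` runs again for `world_size` 1, 2 and 4.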