Commit f5f0ad26 authored by ver217's avatar ver217 Committed by Frank Lee
Browse files

fix bert unit test

parent 56636169
import torch
import transformers
from transformers import BertConfig, BertForSequenceClassification
from packaging import version
from torch.utils.data import SequentialSampler
from transformers import BertConfig, BertForSequenceClassification
from .registry import non_distributed_component_funcs
......@@ -39,14 +39,14 @@ def get_training_components():
num_layer = 2
def bert_model_builder(checkpoint):
config = BertConfig(
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
)
config = BertConfig(gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.,
attention_probs_dropout_prob=0.)
print('building BertForSequenceClassification model')
# adapting huggingface BertForSequenceClassification for single unit test calling interface
......
......@@ -13,6 +13,7 @@ from colossalai.utils import free_port
from colossalai.zero.shard_utils.tensor_shard_strategy import \
TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._zero3_utils import cast_tensor_to_fp16
from tests.components_to_test.registry import non_distributed_component_funcs
from torch.nn.parallel import DistributedDataParallel as DDP
......@@ -45,8 +46,7 @@ def run_fwd_bwd_no_criterion(model, data, label, enable_autocast=False):
def run_dist(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
test_models = ['bert']
# repeated_computed_layers resnet18
test_models = ['repeated_computed_layers', 'resnet18', 'bert']
shard_strategy = TensorShardStrategy()
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
......@@ -65,8 +65,7 @@ def run_dist(rank, world_size, port):
run_fwd_bwd_no_criterion(model, data, label, False)
run_fwd_bwd_no_criterion(zero_model, data, label, False)
else:
# FIXME() data can be integer!
data, label = data.half().cuda(), label.cuda()
data, label = cast_tensor_to_fp16(data).cuda(), label.cuda()
run_fwd_bwd(model, data, label, criterion, False)
run_fwd_bwd(zero_model, data, label, criterion, False)
......@@ -76,7 +75,6 @@ def run_dist(rank, world_size, port):
check_grads(model, zero_model, loose=True)
@pytest.mark.skip(reason="Under development")
@pytest.mark.dist
@pytest.mark.parametrize("world_size", [1, 2, 4])
def test_shard_model_v2(world_size):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment