Unverified Commit 01ea68b2 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[tests] remove T5 test skip decorator (#1271)

parent de498255
...@@ -2,12 +2,20 @@ import pytest ...@@ -2,12 +2,20 @@ import pytest
import transformers import transformers
import torch import torch
from hf_utils import split_model_and_compare_output from hf_utils import split_model_and_compare_output
from colossalai.fx.tracer.meta_patch import meta_patched_module
# Optionally teach the meta-tracer how to handle apex's FusedRMSNorm.
# apex is an optional dependency: when it is not installed the ImportError
# is swallowed and no patch is registered, so tracing simply proceeds
# without FusedRMSNorm support.
try:
    import apex

    @meta_patched_module.register(apex.normalization.FusedRMSNorm)
    def apex_fused_layernorm(self, input):
        """Shape-only stand-in for FusedRMSNorm during meta-tracing.

        Returns an uninitialized tensor on the 'meta' device with the same
        shape as *input* instead of executing the fused CUDA kernel.
        """
        return torch.empty(input.shape, device='meta')
except ImportError:
    # apex not available — skip registration (best-effort, intentional).
    pass
# Small fixture dimensions keep the tracing test fast.
BATCH_SIZE = 1
# NOTE(review): "SEQ_LENGHT" is a typo for SEQ_LENGTH, but it is referenced
# by test_t5's data generator below — rename both together, not here.
SEQ_LENGHT = 16
@pytest.mark.skip('tracing failed')
def test_t5(): def test_t5():
MODEL_LIST = [ MODEL_LIST = [
transformers.T5Model, transformers.T5Model,
...@@ -15,7 +23,7 @@ def test_t5(): ...@@ -15,7 +23,7 @@ def test_t5():
transformers.T5EncoderModel, transformers.T5EncoderModel,
] ]
config = transformers.T5Config(d_model=128, num_layers=2) config = transformers.T5Config(vocab_size=100, d_model=128, num_layers=2)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment