Unverified Commit f57d3495 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[FX] refactor experimental tracer and adapt it with hf models (#3157)

* pass gpt trace and meta_prop

* pass t5 trace and meta_prop

* [FX] refactor experimental tracer and adapt it with hf models

* pass all mainstream model zoo

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* fix CI

* skip tests

* fix CI

* using packaging version

* polish
parent b4295293
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_opt():
sub_registry = model_zoo.get_sub_registry('transformers_opt')
......
import pytest
import torch
from hf_tracer_utils import trace_model_and_compare_output
from packaging import version
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_t5():
sub_registry = model_zoo.get_sub_registry('transformers_t5')
......
import pytest
import timm.models as tm
import torch
from packaging import version
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
......@@ -42,6 +42,7 @@ def trace_and_compare(model_cls, data, output_transform_fn, meta_args=None):
f'{model.__class__.__name__} has inconsistent outputs, {fx_output_val} vs {non_fx_output_val}'
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_timm_models():
torch.backends.cudnn.deterministic = True
......
import re
import pytest
import torch
from packaging import version
from torchaudio_utils import trace_and_compare
from tests.kit.model_zoo import model_zoo
@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.12.0'), reason='torch version < 12')
def test_torchaudio_models():
torch.backends.cudnn.deterministic = True
sub_model_zoo = model_zoo.get_sub_registry('torchaudio')
for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
# FIXME(ver217): temporarily skip these models
if re.search(f'(conformer|emformer|tacotron|wav2vec2_base|hubert_base)', name):
continue
model = model_fn()
trace_and_compare(model,
data_gen_fn,
......
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
def trace_and_compare(model, data_gen, output_transform_fn, need_meta=False, need_concrete=False):
......
import pytest
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
BATCH = 2
......
import pytest
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
BATCH = 2
......
import torch
from colossalai.fx import symbolic_trace
from colossalai._analyzer.fx import symbolic_trace
from tests.kit.model_zoo import model_zoo
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment