"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "7c8be770810835544e2652c6d053e77db83a0949"
Unverified commit b893342f, authored by Super Daniel and committed by GitHub
Browse files

[fx] test tracer on diffuser modules. (#1750)

* [fx] test tracer on diffuser modules.

* [fx] shorter seq_len.

* Update requirements-test.txt
parent b80b6eaa
diffusers
pytest pytest
torchvision torchvision
transformers transformers
......
import transformers
import torch
import pytest import pytest
import torch
import transformers
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
BATCH_SIZE = 2 BATCH_SIZE = 2
SEQ_LENGHT = 16 SEQ_LENGTH = 16
def test_single_sentence_albert(): def test_single_sentence_albert():
...@@ -23,9 +23,9 @@ def test_single_sentence_albert(): ...@@ -23,9 +23,9 @@ def test_single_sentence_albert():
intermediate_size=256) intermediate_size=256)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args return meta_args
......
import transformers
import torch
import pytest import pytest
import torch
import transformers
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
BATCH_SIZE = 2 BATCH_SIZE = 2
SEQ_LENGHT = 16 SEQ_LENGTH = 16
def test_single_sentence_bert(): def test_single_sentence_bert():
...@@ -20,9 +20,9 @@ def test_single_sentence_bert(): ...@@ -20,9 +20,9 @@ def test_single_sentence_bert():
config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256) config = transformers.BertConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4, intermediate_size=256)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) meta_args = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return meta_args return meta_args
......
import diffusers
import pytest
import torch
import transformers
from torch.fx import GraphModule
from utils import trace_model_and_compare_output
from colossalai.fx import ColoTracer
# Shared input dimensions used by the tracer tests in this file.
BATCH_SIZE, SEQ_LENGTH = 2, 5
HEIGHT = WIDTH = 224
IN_CHANNELS = 3
# Sample shape whose spatial size is HEIGHT and WIDTH floor-divided by 8.
LATENTS_SHAPE = (BATCH_SIZE, IN_CHANNELS) + (HEIGHT // 8, WIDTH // 8)
# Second positional argument passed to the UNet models in test_unet.
TIME_STEP = 2
def test_vae():
    """Trace default-constructed diffusers VAE models and compare outputs.

    For ``AutoencoderKL`` and ``VQModel``, the model is traced with
    ``ColoTracer``, wrapped in a ``GraphModule``, and both the traced and the
    original module are run on an all-zero tensor of shape ``LATENTS_SHAPE``.
    The ``'sample'`` entries of the two outputs must satisfy
    ``torch.allclose``.
    """
    for model_cls in (diffusers.AutoencoderKL, diffusers.VQModel):
        model = model_cls()
        zero_input = torch.zeros(LATENTS_SHAPE)

        # Build the traced counterpart of the freshly constructed model.
        traced_graph = ColoTracer().trace(root=model)
        traced_module = GraphModule(model, traced_graph, model.__class__.__name__)
        traced_module.recompile()

        # Put both modules in eval mode before comparing their outputs.
        for module in (model, traced_module):
            module.eval()

        with torch.no_grad():
            fx_out = traced_module(zero_input)
            non_fx_out = model(zero_input)

        outputs_match = torch.allclose(fx_out['sample'], non_fx_out['sample'])
        assert outputs_match, f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
def test_clip():
    """Run the trace-and-compare helper on three CLIP model variants.

    Each model class is instantiated from its default-constructed config and
    passed, together with a zero-argument input generator, to
    ``trace_model_and_compare_output`` (imported from ``utils``).
    """
    model_config_pairs = (
        (transformers.CLIPModel, transformers.CLIPConfig),
        (transformers.CLIPTextModel, transformers.CLIPTextConfig),
        (transformers.CLIPVisionModel, transformers.CLIPVisionConfig),
    )

    def _token_zeros():
        # All-zero int64 tensor of shape (BATCH_SIZE, SEQ_LENGTH).
        return torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)

    def _pixel_zeros():
        # All-zero float32 tensor of shape (BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH).
        return torch.zeros((BATCH_SIZE, IN_CHANNELS, HEIGHT, WIDTH), dtype=torch.float32)

    def data_gen():
        # `model` is looked up in the enclosing scope at call time, so this
        # sees whichever model the loop below has most recently created.
        if isinstance(model, transformers.CLIPModel):
            kwargs = {
                'input_ids': _token_zeros(),
                'attention_mask': _token_zeros(),
                'position_ids': _token_zeros(),
                'pixel_values': _pixel_zeros(),
            }
        elif isinstance(model, transformers.CLIPTextModel):
            kwargs = {
                'input_ids': _token_zeros(),
                'attention_mask': _token_zeros(),
            }
        elif isinstance(model, transformers.CLIPVisionModel):
            kwargs = {'pixel_values': _pixel_zeros()}
        return kwargs

    for model_cls, config_cls in model_config_pairs:
        model = model_cls(config=config_cls())
        trace_model_and_compare_output(model, data_gen)
@pytest.mark.skip(reason='cannot pass the test yet')
def test_unet():
    """Trace default-constructed diffusers UNet models and compare outputs.

    Currently skipped under pytest. For ``UNet2DModel`` and
    ``UNet2DConditionModel``, the model is traced with ``ColoTracer``, wrapped
    in a ``GraphModule``, and both modules are called with an all-zero tensor
    of shape ``LATENTS_SHAPE`` plus ``TIME_STEP``. The ``'sample'`` entries of
    the two outputs must satisfy ``torch.allclose``.
    """
    for model_cls in (diffusers.UNet2DModel, diffusers.UNet2DConditionModel):
        model = model_cls()
        zero_input = torch.zeros(LATENTS_SHAPE)

        # Build the traced counterpart of the freshly constructed model.
        traced_graph = ColoTracer().trace(root=model)
        traced_module = GraphModule(model, traced_graph, model.__class__.__name__)
        traced_module.recompile()

        # Put both modules in eval mode before comparing their outputs.
        for module in (model, traced_module):
            module.eval()

        with torch.no_grad():
            fx_out = traced_module(zero_input, TIME_STEP)
            non_fx_out = model(zero_input, TIME_STEP)

        outputs_match = torch.allclose(fx_out['sample'], non_fx_out['sample'])
        assert outputs_match, f'{model.__class__.__name__} has inconsistent outputs, {fx_out} vs {non_fx_out}'
if __name__ == "__main__":
    # test_unet is intentionally left out of the direct run: it currently fails.
    for run_case in (test_vae, test_clip):
        run_case()
import transformers
import torch
import pytest import pytest
import torch
import transformers
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGTH = 16
def test_gpt(): def test_gpt():
...@@ -19,9 +19,9 @@ def test_gpt(): ...@@ -19,9 +19,9 @@ def test_gpt():
config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4) config = transformers.GPT2Config(n_position=64, n_layer=2, n_head=4)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask) kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
return kwargs return kwargs
......
import pytest import pytest
import transformers
import torch import torch
import transformers
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGTH = 16
def test_opt(): def test_opt():
...@@ -16,8 +16,8 @@ def test_opt(): ...@@ -16,8 +16,8 @@ def test_opt():
config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4) config = transformers.OPTConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=4)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, attention_mask=attention_mask) kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
return kwargs return kwargs
......
import pytest import pytest
import transformers
import torch import torch
import transformers
from utils import trace_model_and_compare_output from utils import trace_model_and_compare_output
BATCH_SIZE = 1 BATCH_SIZE = 1
SEQ_LENGHT = 16 SEQ_LENGTH = 16
def test_t5(): def test_t5():
...@@ -17,13 +17,13 @@ def test_t5(): ...@@ -17,13 +17,13 @@ def test_t5():
config = transformers.T5Config(d_model=128, num_layers=2) config = transformers.T5Config(d_model=128, num_layers=2)
def data_gen(): def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids) kwargs = dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
return kwargs return kwargs
def data_gen_for_encoder_only(): def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64) input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
kwargs = dict(input_ids=input_ids) kwargs = dict(input_ids=input_ids)
return kwargs return kwargs
......
from numpy import isin
import torch import torch
from colossalai.fx import ColoTracer from numpy import isin
from torch.fx import GraphModule from torch.fx import GraphModule
from torch.utils._pytree import tree_flatten from torch.utils._pytree import tree_flatten
from colossalai.fx import ColoTracer
def trace_model_and_compare_output(model, data_gen): def trace_model_and_compare_output(model, data_gen):
# must turn on eval mode to ensure the output is consistent # must turn on eval mode to ensure the output is consistent
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment