"...gpt/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "29695cf70c2652e4017bd76ff6337572f5b05035"
Unverified Commit 32c1b843 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

skip torchrec unittests if not installed (#1790)

parent 0b8161fa
import pytest
import torch
from colossalai.fx.tracer import meta_patch from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import torch from colossalai.fx.tracer.tracer import ColoTracer
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection try:
from torchrec.modules.embedding_configs import EmbeddingBagConfig from torchrec.models import deepfm
from torchrec.models import deepfm, dlrm from torchrec.modules.embedding_configs import EmbeddingBagConfig
import colossalai.fx as fx from torchrec.modules.embedding_modules import EmbeddingBagCollection
import pdb from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
NOT_TORCHREC = False
except ImportError:
NOT_TORCHREC = True
from torch.fx import GraphModule from torch.fx import GraphModule
BATCH = 2 BATCH = 2
SHAPE = 10 SHAPE = 10
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
def test_torchrec_deepfm_models(): def test_torchrec_deepfm_models():
MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch] MODEL_LIST = [deepfm.DenseArch, deepfm.FMInteractionArch, deepfm.OverArch, deepfm.SimpleDeepFMNN, deepfm.SparseArch]
......
from colossalai.fx.tracer import meta_patch
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.fx.tracer.meta_patch.patched_function import python_ops
import torch import torch
from torchrec.sparse.jagged_tensor import KeyedTensor, KeyedJaggedTensor
from torchrec.modules.embedding_modules import EmbeddingBagCollection from colossalai.fx.tracer.tracer import ColoTracer
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.models import deepfm, dlrm try:
import colossalai.fx as fx from torchrec.models import dlrm
import pdb from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
NOT_TORCHREC = False
except ImportError:
NOT_TORCHREC = True
import pytest
from torch.fx import GraphModule from torch.fx import GraphModule
BATCH = 2 BATCH = 2
SHAPE = 10 SHAPE = 10
@pytest.mark.skipif(NOT_TORCHREC, reason='torchrec is not installed')
def test_torchrec_dlrm_models(): def test_torchrec_dlrm_models():
MODEL_LIST = [ MODEL_LIST = [
dlrm.DLRM, dlrm.DLRM,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment