Unverified Commit ff644ee5 authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

polish unitest test with titans (#1152)

parent f1f51990
...@@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc ...@@ -16,12 +16,10 @@ from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.nn import CrossEntropyLoss from colossalai.nn import CrossEntropyLoss
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import is_using_pp, get_dataloader from colossalai.utils import get_dataloader
from colossalai.pipeline.pipelinable import PipelinableContext from colossalai.pipeline.pipelinable import PipelinableContext
from tqdm import tqdm
from torchvision.datasets import CIFAR10 from torchvision.datasets import CIFAR10
from torchvision import transforms from torchvision import transforms
from titans.model.vit import vit_tiny_patch4_32
BATCH_SIZE = 4 BATCH_SIZE = 4
NUM_EPOCHS = 60 NUM_EPOCHS = 60
...@@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port): ...@@ -41,6 +39,12 @@ def run_trainer(rank, world_size, port):
logger = get_dist_logger() logger = get_dist_logger()
pipelinable = PipelinableContext() pipelinable = PipelinableContext()
try:
from titans.model.vit import vit_tiny_patch4_32
except ImportError:
logger.warning('skip the test_cifar_with_data_pipeline_tensor test because titan is not installed')
logger.warning('please install titan from https://github.com/hpcaitech/Titans')
return
with pipelinable: with pipelinable:
model = vit_tiny_patch4_32() model = vit_tiny_patch4_32()
pipelinable.to_layer_list() pipelinable.to_layer_list()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment