Unverified Commit 6674a515 authored by Aryan's avatar Aryan Committed by GitHub
Browse files

Conditionally import torchvision in Cosmos transformer (#11524)

fix
parent 784db0ea
...@@ -18,9 +18,9 @@ import numpy as np ...@@ -18,9 +18,9 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torchvision import transforms
from ...configuration_utils import ConfigMixin, register_to_config from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torchvision_available
from ..attention import FeedForward from ..attention import FeedForward
from ..attention_processor import Attention from ..attention_processor import Attention
from ..embeddings import Timesteps from ..embeddings import Timesteps
...@@ -29,6 +29,10 @@ from ..modeling_utils import ModelMixin ...@@ -29,6 +29,10 @@ from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm from ..normalization import RMSNorm
if is_torchvision_available():
from torchvision import transforms
class CosmosPatchEmbed(nn.Module): class CosmosPatchEmbed(nn.Module):
def __init__( def __init__(
self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True self, in_channels: int, out_channels: int, patch_size: Tuple[int, int, int], bias: bool = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment