Unverified commit 852e7eba authored by Yih-Dar, committed by GitHub

Use `config.num_channels` in CLIP-like modeling files (#20857)



Use config.num_channels in CLIP-like modeling files
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent d87e381f
@@ -181,7 +181,11 @@ class ChineseCLIPVisionEmbeddings(nn.Module):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
@@ -178,7 +178,11 @@ class CLIPVisionEmbeddings(nn.Module):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
@@ -171,7 +171,11 @@ class CLIPSegVisionEmbeddings(nn.Module):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
@@ -129,7 +129,11 @@ class XCLIPVisionEmbeddings(nn.Module):
         self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
 
         self.patch_embedding = nn.Conv2d(
-            in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
         )
 
         self.num_patches = (self.image_size // self.patch_size) ** 2
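With these hunks applied, the patch embedding's Conv2d follows the configured channel count instead of a hard-coded 3. A minimal sketch of the effect, assuming a transformers version where CLIPVisionConfig exposes num_channels (the single-channel value below is only illustrative):

import torch
from transformers import CLIPVisionConfig, CLIPVisionModel

# Configure a CLIP vision tower for single-channel (e.g. grayscale) input.
config = CLIPVisionConfig(num_channels=1, image_size=224, patch_size=32)
model = CLIPVisionModel(config)

# The patch embedding now expects config.num_channels input channels, so a
# (batch, 1, 224, 224) tensor is accepted without editing the modeling code.
pixel_values = torch.randn(1, 1, 224, 224)
outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)  # (1, 1 + (224 // 32) ** 2, hidden_size)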