ModelZoo / SED_pytorch - Commits

Commit f55a786e, authored Jun 05, 2024 by luopl: Initial commit.
Pipeline #1081 canceled with stages.
The commit touches 181 files in total; this page shows 20 changed files with 1401 additions and 0 deletions (+1401, -0).
Changed files on this page:

open_clip/src/open_clip/model_configs/convnext_large_d_320.json  (+20, -0)
open_clip/src/open_clip/model_configs/convnext_small.json  (+18, -0)
open_clip/src/open_clip/model_configs/convnext_tiny.json  (+18, -0)
open_clip/src/open_clip/model_configs/convnext_xlarge.json  (+18, -0)
open_clip/src/open_clip/model_configs/convnext_xxlarge.json  (+18, -0)
open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json  (+18, -0)
open_clip/src/open_clip/model_configs/mt5-base-ViT-B-32.json  (+15, -0)
open_clip/src/open_clip/model_configs/mt5-xl-ViT-H-14.json  (+16, -0)
open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json  (+16, -0)
open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json  (+18, -0)
open_clip/src/open_clip/model_configs/vit_medium_patch16_gap_256.json  (+18, -0)
open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json  (+18, -0)
open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json  (+15, -0)
open_clip/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json  (+16, -0)
open_clip/src/open_clip/modified_resnet.py  (+181, -0)
open_clip/src/open_clip/openai.py  (+144, -0)
open_clip/src/open_clip/pretrained.py  (+354, -0)
open_clip/src/open_clip/timm_model.py  (+146, -0)
open_clip/src/open_clip/tokenizer.py  (+201, -0)
open_clip/src/open_clip/transform.py  (+133, -0)
open_clip/src/open_clip/model_configs/convnext_large_d_320.json (new file, mode 100644)

{
    "embed_dim": 768,
    "vision_cfg": {
        "timm_model_name": "convnext_large",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "mlp",
        "timm_drop": 0.1,
        "timm_drop_path": 0.1,
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 768,
        "heads": 12,
        "layers": 16
    }
}
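These model_configs files are plain JSON: a top-level embed_dim plus vision_cfg and text_cfg sub-dicts, where timm_* keys route the vision tower through the TimmModel adapter added later in this commit. The following is a minimal sketch of reading one of these configs, assuming only the Python standard library and the repository layout used here (the path is illustrative):

import json

# Hypothetical checkout-relative path; adjust to wherever the repo lives.
cfg_path = "open_clip/src/open_clip/model_configs/convnext_large_d_320.json"

with open(cfg_path) as f:
    cfg = json.load(f)

# The config splits into the pieces the open_clip factory consumes.
embed_dim = cfg["embed_dim"]    # 768 for this file
vision_cfg = cfg["vision_cfg"]  # timm_* keys -> TimmModel vision tower
text_cfg = cfg["text_cfg"]      # CLIP text transformer (or hf_model_name for HF towers)
print(embed_dim, vision_cfg["timm_model_name"], text_cfg["layers"])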
open_clip/src/open_clip/model_configs/convnext_small.json (new file, mode 100644)

{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "convnext_small",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/convnext_tiny.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_tiny",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/convnext_xlarge.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 16
    }
}
open_clip/src/open_clip/model_configs/convnext_xxlarge.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}
open_clip/src/open_clip/model_configs/convnext_xxlarge_320.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "timm_model_name": "convnext_xxlarge",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 320
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 1024,
        "heads": 16,
        "layers": 24
    }
}
open_clip/src/open_clip/model_configs/mt5-base-ViT-B-32.json (new file, mode 100644)

{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-base",
        "hf_tokenizer_name": "google/mt5-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
open_clip/src/open_clip/model_configs/mt5-xl-ViT-H-14.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "google/mt5-xl",
        "hf_tokenizer_name": "google/mt5-xl",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
open_clip/src/open_clip/model_configs/roberta-ViT-B-32.json (new file, mode 100644)

{
    "embed_dim": 512,
    "quick_gelu": true,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "roberta-base",
        "hf_tokenizer_name": "roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
open_clip/src/open_clip/model_configs/swin_base_patch4_window7_224.json (new file, mode 100644)

{
    "embed_dim": 640,
    "vision_cfg": {
        "timm_model_name": "swin_base_patch4_window7_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 640,
        "heads": 10,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/vit_medium_patch16_gap_256.json (new file, mode 100644)

{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_medium_patch16_gap_256",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 256
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/vit_relpos_medium_patch16_cls_224.json (new file, mode 100644)

{
    "embed_dim": 512,
    "vision_cfg": {
        "timm_model_name": "vit_relpos_medium_patch16_cls_224",
        "timm_model_pretrained": false,
        "timm_pool": "",
        "timm_proj": "linear",
        "image_size": 224
    },
    "text_cfg": {
        "context_length": 77,
        "vocab_size": 49408,
        "width": 512,
        "heads": 8,
        "layers": 12
    }
}
open_clip/src/open_clip/model_configs/xlm-roberta-base-ViT-B-32.json (new file, mode 100644)

{
    "embed_dim": 512,
    "vision_cfg": {
        "image_size": 224,
        "layers": 12,
        "width": 768,
        "patch_size": 32
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-base",
        "hf_tokenizer_name": "xlm-roberta-base",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
open_clip/src/open_clip/model_configs/xlm-roberta-large-ViT-H-14.json (new file, mode 100644)

{
    "embed_dim": 1024,
    "vision_cfg": {
        "image_size": 224,
        "layers": 32,
        "width": 1280,
        "head_width": 80,
        "patch_size": 14
    },
    "text_cfg": {
        "hf_model_name": "xlm-roberta-large",
        "hf_tokenizer_name": "xlm-roberta-large",
        "proj": "mlp",
        "pooler_type": "mean_pooler"
    }
}
open_clip/src/open_clip/modified_resnet.py (new file, mode 100644)

from collections import OrderedDict

import torch
from torch import nn
from torch.nn import functional as F

from open_clip.utils import freeze_batch_norm_2d


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        if stride > 1 or inplanes != planes * Bottleneck.expansion:
            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        identity = x

        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        out = self.avgpool(out)
        out = self.bn3(self.conv3(out))

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.act3(out)
        return out


class AttentionPool2d(nn.Module):
    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        return x[0]


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def init_parameters(self):
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features ** -0.5
            nn.init.normal_(self.attnpool.q_proj.weight, std=std)
            nn.init.normal_(self.attnpool.k_proj.weight, std=std)
            nn.init.normal_(self.attnpool.v_proj.weight, std=std)
            nn.init.normal_(self.attnpool.c_proj.weight, std=std)

        for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
            for name, param in resnet_block.named_parameters():
                if name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        x = self.act1(self.bn1(self.conv1(x)))
        x = self.act2(self.bn2(self.conv2(x)))
        x = self.act3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
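A minimal usage sketch for the module above, assuming the package under open_clip/src is importable (e.g. with open_clip/src on PYTHONPATH) and that open_clip.utils provides freeze_batch_norm_2d as imported here. The layer counts and dimensions below mirror CLIP's RN50 layout and are illustrative, not taken from this commit:

import torch
from open_clip.modified_resnet import ModifiedResNet

# RN50-style configuration: 4 residual stages, attention-pooled 1024-d output.
model = ModifiedResNet(layers=[3, 4, 6, 3], output_dim=1024, heads=32, image_size=224, width=64)
model.eval()

with torch.no_grad():
    images = torch.randn(2, 3, 224, 224)  # NCHW batch of random pixels
    features = model(images)

print(features.shape)  # torch.Size([2, 1024]) -- one pooled embedding per image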
open_clip/src/open_clip/openai.py (new file, mode 100644)

""" OpenAI pretrained model functions

Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""

import os
import warnings
from typing import List, Optional, Union

import torch

from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url

__all__ = ["list_openai_models", "load_openai_model"]


def list_openai_models() -> List[str]:
    """Returns the names of available CLIP models"""
    return list_pretrained_models_by_tag('openai')


def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = True,
        cache_dir: Optional[str] = None,
):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
    precision: str
        Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
    cache_dir : Optional[str]
        The directory to cache the downloaded model weights

    Returns
    -------
    model : torch.nn.Module
        The CLIP model
    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        precision = 'fp32' if device == 'cpu' else 'fp16'

    if get_pretrained_url(name, 'openai'):
        model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Build a non-jit model from the OpenAI jitted model state dict
        cast_dtype = get_cast_dtype(precision)
        try:
            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
        except KeyError:
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

        # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
        model = model.to(device)
        if precision.startswith('amp') or precision == 'fp32':
            model.float()
        elif precision == 'bf16':
            convert_weights_to_lp(model, dtype=torch.bfloat16)

        return model

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 (typically for CPU)
    if precision == 'fp32':
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    # ensure image_size attr available at consistent location for both jit and non-jit
    model.visual.image_size = model.input_resolution.item()
    return model
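A usage sketch for the loader above, again assuming the package is importable; note that loading by name downloads the OpenAI checkpoint on first use (several hundred MB for RN50), so this is a sketch rather than something to run blindly:

from open_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())  # names registered under the 'openai' pretrain tag, e.g. 'RN50', 'ViT-B-32', ...

# jit=False takes the "hackable" non-JIT path; on CPU the default precision is fp32.
model = load_openai_model("RN50", device="cpu", jit=False)
model.eval()
print(model.visual.image_size)  # input resolution the checkpoint was trained at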
open_clip/src/open_clip/pretrained.py (new file, mode 100644)

import hashlib
import os
import urllib
import warnings
from functools import partial
from typing import Dict, Union

from tqdm import tqdm

from .version import __version__

try:
    from huggingface_hub import hf_hub_download
    hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False


def _pcfg(url='', hf_hub='', mean=None, std=None):
    return dict(
        url=url,
        hf_hub=hf_hub,
        mean=mean,
        std=std,
    )


_RN50 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN50_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"),
    cc12m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"),
)

_RN101 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN101_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"),
    yfcc15m=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"),
)

_RN50x4 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"),
)

_RN50x16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"),
)

_RN50x64 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"),
)

_VITB32 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
    laion2b_e16=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
    laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
)

_VITB32_quickgelu = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
)

_VITB16 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
    # laion400m_32k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    # laion400m_64k=_pcfg(
    #     url="",
    #     mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
)

_VITB16_PLUS_240 = dict(
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
)

_VITL14 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
    laion400m_e31=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
    laion400m_e32=_pcfg(
        "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
    laion2b_s32b_b82k=_pcfg(
        hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
)

_VITL14_336 = dict(
    openai=_pcfg(
        "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
)

_VITH14 = dict(
    laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
)

_VITg14 = dict(
    laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
)

_VITbigG14 = dict(
    laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
)

_robertaViTB32 = dict(
    laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'),
)

_xlmRobertaBaseViTB32 = dict(
    laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'),
)

_xlmRobertaLargeFrozenViTH14 = dict(
    frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'),
)

_convnext_base = dict(
    laion400m_s13b_b51k=_pcfg(hf_hub='convnext_base-laion400M-s13B-b51K'),
)

_convnext_base_w = dict(
    laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'),
    laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'),
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'),
)

_convnext_base_w_320 = dict(
    laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'),
    laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'),
)

_convnext_large_d_320 = dict(
    laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'),
)

_PRETRAINED = {
    "RN50": _RN50,
    "RN50-quickgelu": _RN50_quickgelu,
    "RN101": _RN101,
    "RN101-quickgelu": _RN101_quickgelu,
    "RN50x4": _RN50x4,
    "RN50x16": _RN50x16,
    "RN50x64": _RN50x64,
    "ViT-B-32": _VITB32,
    "ViT-B-32-quickgelu": _VITB32_quickgelu,
    "ViT-B-16": _VITB16,
    "ViT-B-16-plus-240": _VITB16_PLUS_240,
    "ViT-L-14": _VITL14,
    "ViT-L-14-336": _VITL14_336,
    "ViT-H-14": _VITH14,
    "ViT-g-14": _VITg14,
    "ViT-bigG-14": _VITbigG14,
    "roberta-ViT-B-32": _robertaViTB32,
    "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32,
    "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14,
    "convnext_base": _convnext_base,
    "convnext_base_w": _convnext_base_w,
    "convnext_base_w_320": _convnext_base_w_320,
    "convnext_large_d_320": _convnext_large_d_320,
}


def _clean_tag(tag: str):
    # normalize pretrained tags
    return tag.lower().replace('-', '_')


def list_pretrained(as_str: bool = False):
    """ returns list of pretrained models
    Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
    """
    return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]


def list_pretrained_models_by_tag(tag: str):
    """ return all models having the specified pretrain tag """
    models = []
    tag = _clean_tag(tag)
    for k in _PRETRAINED.keys():
        if tag in _PRETRAINED[k]:
            models.append(k)
    return models


def list_pretrained_tags_by_model(model: str):
    """ return all pretrain tags for the specified model architecture """
    tags = []
    if model in _PRETRAINED:
        tags.extend(_PRETRAINED[model].keys())
    return tags


def is_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return False
    return _clean_tag(tag) in _PRETRAINED[model]


def get_pretrained_cfg(model: str, tag: str):
    if model not in _PRETRAINED:
        return {}
    model_pretrained = _PRETRAINED[model]
    return model_pretrained.get(_clean_tag(tag), {})


def get_pretrained_url(model: str, tag: str):
    cfg = get_pretrained_cfg(model, _clean_tag(tag))
    return cfg.get('url', '')


def download_pretrained_from_url(
        url: str,
        cache_dir: Union[str, None] = None,
):
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if expected_sha256:
            if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
                return download_target
            else:
                warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
        else:
            return download_target

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def download_pretrained_from_hf(
        model_id: str,
        filename: str = 'open_clip_pytorch_model.bin',
        revision=None,
        cache_dir: Union[str, None] = None,
):
    has_hf_hub(True)
    # cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
    # ConvNeXt-B
    cached_file = './weights/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/open_clip_pytorch_model.bin'
    # ConvNeXt-L
    # cached_file = './weights/laionCLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/open_clip_pytorch_model.bin'
    return cached_file


def download_pretrained(
        cfg: Dict,
        force_hf_hub: bool = False,
        cache_dir: Union[str, None] = None,
):
    target = ''
    if not cfg:
        return target

    download_url = cfg.get('url', '')
    download_hf_hub = cfg.get('hf_hub', '')
    if download_hf_hub and force_hf_hub:
        # use HF hub even if url exists
        download_url = ''

    if download_url:
        target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
    elif download_hf_hub:
        has_hf_hub(True)
        # we assume the hf_hub entries in pretrained config combine model_id + filename in
        # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
        # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
        model_id, filename = os.path.split(download_hf_hub)
        if filename:
            target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
        else:
            target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)

    return target
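A small sketch of the query helpers in this registry, assuming the package is importable. Note that in this commit download_pretrained_from_hf bypasses the Hub call and returns a hard-coded local path under ./weights/, so hf_hub entries resolve to that file rather than an actual download:

from open_clip.pretrained import (
    list_pretrained, list_pretrained_tags_by_model, get_pretrained_cfg, get_pretrained_url)

print(list_pretrained(as_str=True)[:3])           # e.g. ['RN50:openai', 'RN50:yfcc15m', 'RN50:cc12m']
print(list_pretrained_tags_by_model("ViT-B-32"))  # tags registered for that architecture

cfg = get_pretrained_cfg("convnext_base_w", "laion2b_s13b_b82k")
print(cfg.get("hf_hub"))                          # HF Hub id consumed by download_pretrained()
print(get_pretrained_url("RN50", "openai"))       # direct-URL entries go through download_pretrained_from_url()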
open_clip/src/open_clip/timm_model.py (new file, mode 100644)

""" timm model adapter

Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
"""
import logging
from collections import OrderedDict

import torch
import torch.nn as nn
from einops import rearrange

try:
    import timm
    from timm.models.layers import Mlp, to_2tuple
    try:
        # old timm imports < 0.8.1
        from timm.models.layers.attention_pool2d import RotAttentionPool2d
        from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
    except ImportError:
        # new timm imports >= 0.8.1
        from timm.layers import RotAttentionPool2d
        from timm.layers import AttentionPool2d as AbsAttentionPool2d
except ImportError:
    timm = None

from .utils import freeze_batch_norm_2d


class TimmModel(nn.Module):
    """ timm model adapter
    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
            self,
            model_name,
            embed_dim,
            image_size=224,
            pool='avg',
            proj='linear',
            proj_bias=False,
            drop=0.,
            drop_path=None,
            pretrained=False,
    ):
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        timm_kwargs = {}
        if drop_path is not None:
            timm_kwargs['drop_path_rate'] = drop_path
        self.trunk = timm.create_model(model_name, pretrained=pretrained, **timm_kwargs)
        feat_size = self.trunk.default_cfg.get('pool_size', None)
        feature_ndim = 1 if not feat_size else 2
        if pool in ('abs_attn', 'rot_attn'):
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool='')
        else:
            # reset global pool if pool config set, otherwise leave as network default
            reset_kwargs = dict(global_pool=pool) if pool else {}
            self.trunk.reset_classifier(0, **reset_kwargs)
        prev_chs = self.trunk.num_features

        head_layers = OrderedDict()
        if pool == 'abs_attn':
            head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
            prev_chs = embed_dim
        elif pool == 'rot_attn':
            head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
            prev_chs = embed_dim
        else:
            assert proj, 'projection layer needed if non-attention pooling is used.'

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == 'linear':
            head_layers['drop'] = nn.Dropout(drop)
            head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
        elif proj == 'mlp':
            head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias))

        self.head = nn.Sequential(head_layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """ lock modules
        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
        """
        if not unlocked_groups:
            # lock full model
            for param in self.trunk.parameters():
                param.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
        else:
            # NOTE: partial freeze requires latest timm (master) branch and is subject to change
            try:
                # FIXME import here until API stable and in an official release
                from timm.models.helpers import group_parameters, group_modules
            except ImportError:
                raise RuntimeError(
                    'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
            matcher = self.trunk.group_matcher()
            gparams = group_parameters(self.trunk, matcher)
            max_layer_id = max(gparams.keys())
            max_layer_id = max_layer_id - unlocked_groups
            for group_idx in range(max_layer_id + 1):
                group = gparams[group_idx]
                for param in group:
                    self.trunk.get_parameter(param).requires_grad = False
            if freeze_bn_stats:
                gmodules = group_modules(self.trunk, matcher, reverse=True)
                gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
                freeze_batch_norm_2d(self.trunk, gmodules)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        try:
            self.trunk.set_grad_checkpointing(enable)
        except Exception as e:
            logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')

    def forward(self, x, dense=False):
        out = {}
        x = self.trunk.stem(x)
        out['stem'] = x.contiguous()  # os4
        for i in range(4):
            x = self.trunk.stages[i](x)
            out[f'res{i + 2}'] = x.contiguous()  # res 2 (os4), 3 (os8), 4 (os16), 5 (os32)
        x = self.trunk.norm_pre(x)
        B, C, H, W = x.size()
        x = rearrange(x, "B C H W ->B (H W) C")
        x = self.visual_prediction_forward_convnext(x)
        x = rearrange(x, "B (H W) C ->B C H W", H=H)
        out['clip_vis_dense'] = x.contiguous()
        return out

    def visual_prediction_forward_convnext(self, x):
        batch, num_query, channel = x.shape
        x = x.reshape(batch * num_query, channel, 1, 1)  # fake 2D input
        x = self.trunk.head(x)
        x = self.head(x)
        return x.view(batch, num_query, x.shape[-1])  # B x num_queries x 640
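A minimal construction sketch for the adapter above, mirroring convnext_tiny.json from this commit (ConvNeXt-Tiny trunk, linear projection to 1024). It assumes timm and einops are installed and that the timm ConvNeXt trunk exposes stem/stages/norm_pre/head the way this forward() expects, which is version-dependent; note that forward() here returns a dict of multi-scale feature maps rather than a single pooled embedding:

import torch
from open_clip.timm_model import TimmModel  # requires `pip install timm einops`

# Same knobs as convnext_tiny.json: pool '', linear projection, 224px input.
tower = TimmModel('convnext_tiny', embed_dim=1024, image_size=224, pool='', proj='linear', pretrained=False)
tower.eval()

with torch.no_grad():
    out = tower(torch.randn(1, 3, 224, 224))

print(sorted(out.keys()))           # ['clip_vis_dense', 'res2', 'res3', 'res4', 'res5', 'stem']
print(out['clip_vis_dense'].shape)  # projected dense features, channels == embed_dim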
open_clip/src/open_clip/tokenizer.py (new file, mode 100644)

""" CLIP tokenizer

Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
"""
import gzip
import html
import os
from functools import lru_cache
from typing import Union, List

import ftfy
import regex as re
import torch

# https://stackoverflow.com/q/62691279
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
    cs = bs[:]
    n = 0
    for b in range(2 ** 8):
        if b not in bs:
            bs.append(b)
            cs.append(2 ** 8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + '</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        if not special_tokens:
            special_tokens = ['<start_of_text>', '<end_of_text>']
        else:
            special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE)

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token + '</w>'

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text


_tokenizer = SimpleTokenizer()


def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            tokens[-1] = eot_token
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result


class HFTokenizer:
    "HuggingFace tokenizer wrapper"

    def __init__(self, tokenizer_name: str):
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor:
        # same cleaning as for default tokenizer, except lowercasing
        # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
        if isinstance(texts, str):
            texts = [texts]
        texts = [whitespace_clean(basic_clean(text)) for text in texts]
        input_ids = self.tokenizer(
            texts,
            return_tensors='pt',
            max_length=context_length,
            padding='max_length',
            truncation=True,
        ).input_ids
        return input_ids
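A usage sketch for the tokenizer module, assuming ftfy and regex are installed and the bundled bpe_simple_vocab_16e6.txt.gz sits next to tokenizer.py (it is not part of this page of the diff):

from open_clip.tokenizer import tokenize, SimpleTokenizer

# tokenize() pads/truncates every prompt to context_length and wraps it in <start_of_text>/<end_of_text>.
tokens = tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape)      # torch.Size([2, 77])

tok = SimpleTokenizer()
ids = tok.encode("a photo of a cat")
print(tok.decode(ids))   # decodes back to the cleaned, lower-cased text (BPE end-of-word markers become spaces)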
open_clip/src/open_clip/transform.py (new file, mode 100644)

import warnings
from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
    CenterCrop

from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD


@dataclass
class AugmentationCfg:
    scale: Tuple[float, float] = (0.9, 1.0)
    ratio: Optional[Tuple[float, float]] = None
    color_jitter: Optional[Union[float, Tuple[float, float, float]]] = None
    interpolation: Optional[str] = None
    re_prob: Optional[float] = None
    re_count: Optional[int] = None
    use_timm: bool = False


class ResizeMaxSize(nn.Module):

    def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
        super().__init__()
        if not isinstance(max_size, int):
            raise TypeError(f"Size should be int. Got {type(max_size)}")
        self.max_size = max_size
        self.interpolation = interpolation
        self.fn = min if fn == 'min' else min
        self.fill = fill

    def forward(self, img):
        if isinstance(img, torch.Tensor):
            height, width = img.shape[:2]
        else:
            width, height = img.size
        scale = self.max_size / float(max(height, width))
        if scale != 1.0:
            new_size = tuple(round(dim * scale) for dim in (height, width))
            img = F.resize(img, new_size, self.interpolation)
            pad_h = self.max_size - new_size[0]
            pad_w = self.max_size - new_size[1]
            img = F.pad(img, padding=[pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2], fill=self.fill)
        return img


def _convert_to_rgb(image):
    return image.convert('RGB')


def image_transform(
        image_size: int,
        is_train: bool,
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
        resize_longest_max: bool = False,
        fill_color: int = 0,
        aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
):
    mean = mean or OPENAI_DATASET_MEAN
    if not isinstance(mean, (list, tuple)):
        mean = (mean,) * 3

    std = std or OPENAI_DATASET_STD
    if not isinstance(std, (list, tuple)):
        std = (std,) * 3

    if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
        # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
        image_size = image_size[0]

    if isinstance(aug_cfg, dict):
        aug_cfg = AugmentationCfg(**aug_cfg)
    else:
        aug_cfg = aug_cfg or AugmentationCfg()
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
        use_timm = aug_cfg_dict.pop('use_timm', False)
        if use_timm:
            from timm.data import create_transform  # timm can still be optional
            if isinstance(image_size, (tuple, list)):
                assert len(image_size) >= 2
                input_size = (3,) + image_size[-2:]
            else:
                input_size = (3, image_size, image_size)
            # by default, timm aug randomly alternates bicubic & bilinear for better robustness at inference time
            aug_cfg_dict.setdefault('interpolation', 'random')
            aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
            train_transform = create_transform(
                input_size=input_size,
                is_training=True,
                hflip=0.,
                mean=mean,
                std=std,
                re_mode='pixel',
                **aug_cfg_dict,
            )
        else:
            train_transform = Compose([
                RandomResizedCrop(
                    image_size,
                    scale=aug_cfg_dict.pop('scale'),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                _convert_to_rgb,
                ToTensor(),
                normalize,
            ])
            if aug_cfg_dict:
                warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).')
        return train_transform
    else:
        if resize_longest_max:
            transforms = [
                ResizeMaxSize(image_size, fill=fill_color)
            ]
        else:
            transforms = [
                Resize(image_size, interpolation=InterpolationMode.BICUBIC),
                CenterCrop(image_size),
            ]
        transforms.extend([
            _convert_to_rgb,
            ToTensor(),
            normalize,
        ])
        return Compose(transforms)
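A usage sketch for image_transform, assuming the package (and its .constants module with the OpenAI mean/std) is importable and torchvision plus Pillow are installed:

from PIL import Image
from open_clip.transform import image_transform

# Evaluation-time preprocessing: resize + center-crop to 224, convert to RGB, normalize with OpenAI stats.
preprocess = image_transform(image_size=224, is_train=False)

img = Image.new("RGB", (640, 480))   # stand-in for a real image
tensor = preprocess(img)
print(tensor.shape)                  # torch.Size([3, 224, 224])

# Training-time pipeline built from AugmentationCfg defaults (RandomResizedCrop with scale (0.9, 1.0)).
train_preprocess = image_transform(image_size=224, is_train=True)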