Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
SED_pytorch
Commits
f55a786e
"lib/runtime/src/prelude.rs" did not exist on "183941fae1250a658807eb9f88076de857d69750"
Commit
f55a786e
authored
Jun 05, 2024
by
luopl
Browse files
Initial commit
parents
Pipeline
#1081
canceled with stages
Changes
181
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3278 additions
and
0 deletions
+3278
-0
open_clip/src/open_clip/transformer.py
open_clip/src/open_clip/transformer.py
+507
-0
open_clip/src/open_clip/utils.py
open_clip/src/open_clip/utils.py
+60
-0
open_clip/src/open_clip/version.py
open_clip/src/open_clip/version.py
+1
-0
open_clip/src/training/.gitignore
open_clip/src/training/.gitignore
+1
-0
open_clip/src/training/__init__.py
open_clip/src/training/__init__.py
+0
-0
open_clip/src/training/data.py
open_clip/src/training/data.py
+514
-0
open_clip/src/training/distributed.py
open_clip/src/training/distributed.py
+137
-0
open_clip/src/training/file_utils.py
open_clip/src/training/file_utils.py
+83
-0
open_clip/src/training/imagenet_zeroshot_data.py
open_clip/src/training/imagenet_zeroshot_data.py
+254
-0
open_clip/src/training/logger.py
open_clip/src/training/logger.py
+26
-0
open_clip/src/training/main.py
open_clip/src/training/main.py
+446
-0
open_clip/src/training/params.py
open_clip/src/training/params.py
+403
-0
open_clip/src/training/precision.py
open_clip/src/training/precision.py
+12
-0
open_clip/src/training/profile.py
open_clip/src/training/profile.py
+158
-0
open_clip/src/training/scheduler.py
open_clip/src/training/scheduler.py
+53
-0
open_clip/src/training/train.py
open_clip/src/training/train.py
+308
-0
open_clip/src/training/zero_shot.py
open_clip/src/training/zero_shot.py
+93
-0
open_clip/tests/test_download_pretrained.py
open_clip/tests/test_download_pretrained.py
+111
-0
open_clip/tests/test_hf_model.py
open_clip/tests/test_hf_model.py
+29
-0
open_clip/tests/test_inference.py
open_clip/tests/test_inference.py
+82
-0
No files found.
open_clip/src/open_clip/transformer.py
0 → 100644
View file @
f55a786e
from
collections
import
OrderedDict
import
math
from
typing
import
Callable
,
Optional
,
Sequence
import
torch
from
torch
import
nn
from
torch.nn
import
functional
as
F
from
torch.utils.checkpoint
import
checkpoint
from
.utils
import
to_2tuple
class LayerNormFp32(nn.LayerNorm):
    """LayerNorm that always normalizes in float32, then casts back.

    Upcasting to float32 before ``F.layer_norm`` keeps fp16 inputs
    numerically stable; the output dtype matches the input dtype.
    """

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normed = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return normed.to(input_dtype)
class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm (with cast back to input dtype)."""

    def forward(self, x: torch.Tensor):
        # Remember the incoming dtype so the result can be cast back to it.
        dtype_in = x.dtype
        out = F.layer_norm(
            x,
            self.normalized_shape,
            self.weight,
            self.bias,
            self.eps,
        )
        return out.to(dtype_in)
class QuickGELU(nn.Module):
    """Sigmoid-gated GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
class LayerScale(nn.Module):
    """Learnable per-channel scaling (``gamma``), optionally applied in place.

    Args:
        dim: number of channels to scale.
        init_values: initial value for every gamma entry.
        inplace: if True, multiply the input tensor in place.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class PatchDropout(nn.Module):
    """Randomly keep a subset of patch tokens during training.

    https://arxiv.org/abs/2212.00794

    With ``exclude_first_token`` the CLS token is always preserved and
    re-attached after dropping. In eval mode, or with ``prob == 0``, the
    input is returned unchanged.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token

    def forward(self, x):
        # Identity when not training or dropout is disabled.
        if not self.training or self.prob == 0.:
            return x

        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        # Column index per batch row, broadcast against the kept-patch indices.
        batch_indices = torch.arange(batch)[..., None]

        keep_prob = 1 - self.prob
        # Always keep at least one patch token.
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Random scores -> top-k indices selects a uniform random subset per row.
        rand = torch.randn(batch, num_tokens)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x
class Attention(nn.Module):
    """Multi-head self-attention with optional scaled-cosine attention
    (per-head learnable logit scale) and optional per-head output scaling.

    Input/output layout is (L, N, C) = (seq_len, batch, dim), matching
    ``nn.MultiheadAttention``'s default.
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=True,
            scaled_cosine=False,
            scale_heads=False,
            logit_scale_max=math.log(1. / 0.01),
            attn_drop=0.,
            proj_drop=0.
    ):
        super().__init__()
        self.scaled_cosine = scaled_cosine
        self.scale_heads = scale_heads
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(head_dim) attention scaling (used when not scaled-cosine).
        self.scale = self.head_dim ** -0.5
        # Upper clamp for the learnable logit scale (log-space), default log(100).
        self.logit_scale_max = logit_scale_max

        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
        self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
        if qkv_bias:
            self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
        else:
            self.in_proj_bias = None

        if self.scaled_cosine:
            # Per-head learnable logit scale, stored in log space, init log(10).
            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        else:
            self.logit_scale = None
        self.attn_drop = nn.Dropout(attn_drop)
        if self.scale_heads:
            # Per-head learnable scaling applied to the attention output.
            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
        else:
            self.head_scale = None
        self.out_proj = nn.Linear(dim, dim)
        self.out_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
        """Apply self-attention to x of shape (L, N, C); attn_mask is an
        additive float mask or a boolean mask (True = masked position)."""
        L, N, C = x.shape
        # Single fused projection, then split into q, k, v along the last dim.
        q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1)
        # Fold heads into the batch dim: (L, N, C) -> (N*heads, L, head_dim).
        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)

        if self.logit_scale is not None:
            # Scaled-cosine attention: cosine similarity of q and k,
            # multiplied by a clamped, exponentiated per-head scale.
            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
            attn = attn.view(N, self.num_heads, L, L) * logit_scale
            attn = attn.view(-1, L, L)
        else:
            # Standard dot-product attention.
            q = q * self.scale
            attn = torch.bmm(q, k.transpose(-1, -2))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Convert boolean mask to additive -inf mask.
                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
                attn_mask = new_attn_mask
            attn += attn_mask

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = torch.bmm(attn, v)
        if self.head_scale is not None:
            # NOTE(review): x here is (N*num_heads, L, head_dim); viewing it as
            # (N, num_heads, L, C) only has matching element count when
            # num_heads == 1 — confirm head_scale usage before relying on it.
            x = x.view(N, self.num_heads, L, C) * self.head_scale
            x = x.view(-1, L, C)
        # Back to (L, N, C) and project out.
        x = x.transpose(0, 1).reshape(L, N, C)
        x = self.out_proj(x)
        x = self.out_drop(x)
        return x
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + ls_1(attn(ln_1(x))), then
    x + ls_2(mlp(ln_2(x))). ``ls_*`` are LayerScale when ``ls_init_value``
    is given, otherwise identity."""

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Self-attention with the mask cast to the input dtype."""
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        # [0] drops the (unused) attention-weight output.
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.ls_1(self.attention(self.ln_1(x), attn_mask=attn_mask))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x

    def forward_dense(self, x):
        """Dense-feature variant: skips the attention softmax and uses the
        value projection directly (applied per token), as used for dense
        prediction with the last transformer block.

        x is (L, N, D); returns (L, N, D)."""
        y = self.ln_1(x)
        # Full qkv projection: (L, N, 3D).
        y = F.linear(y, self.attn.in_proj_weight, self.attn.in_proj_bias)
        L, N, D = y.shape  # L N 3D
        # Split the 3D axis into (q, k, v) stacked along the batch dim: (3N, L, D/3).
        y = y.reshape(L, N, 3, D // 3).permute(2, 1, 0, 3).reshape(3 * N, L, D // 3)
        # NOTE(review): out_proj is applied to q, k and v here (only v is used
        # below) — presumably intentional for this dense path; confirm.
        y = F.linear(y, self.attn.out_proj.weight, self.attn.out_proj.bias)
        q, k, v = y.tensor_split(3, dim=0)
        # Residual connection using the projected values only.
        v = v.transpose(1, 0) + x  # L N D
        v = v + self.mlp(self.ln_2(v))
        return v
class CustomResidualAttentionBlock(nn.Module):
    """Residual block using the custom :class:`Attention` (supports
    scaled-cosine attention, per-head scaling, and optional extra norms
    after attention / inside the MLP)."""

    def __init__(
            self,
            d_model: int,
            n_head: int,
            mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            scale_cosine_attn: bool = False,
            scale_heads: bool = False,
            scale_attn: bool = False,
            scale_fc: bool = False,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = Attention(
            d_model, n_head,
            scaled_cosine=scale_cosine_attn,
            scale_heads=scale_heads,
        )
        # Optional extra norm on the attention output.
        self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity()
        self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, mlp_width)),
            # Optional norm between the fc and the activation.
            ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()),
            ("gelu", act_layer()),
            ("c_proj", nn.Linear(mlp_width, d_model))
        ]))
        self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity()

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask)))
        x = x + self.ls_2(self.mlp(self.ln_2(x)))
        return x
class Transformer(nn.Module):
    """Stack of :class:`ResidualAttentionBlock`; operates on (L, N, D)."""

    def __init__(
            self,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float = 4.0,
            ls_init_value: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        # Toggled externally (set_grad_checkpointing) to trade compute for memory.
        self.grad_checkpointing = False

        self.resblocks = nn.ModuleList([
            ResidualAttentionBlock(
                width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer)
            for _ in range(layers)
        ])

    def get_cast_dtype(self) -> torch.dtype:
        """Dtype the blocks compute in (taken from the first MLP weight)."""
        return self.resblocks[0].mlp.c_fc.weight.dtype

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None, dense=False):
        """Run all blocks; when ``dense`` is True the LAST block uses its
        dense-feature path (``forward_dense``) instead of regular attention."""
        for i, r in enumerate(self.resblocks):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(r, x, attn_mask)
            else:
                if dense and i == self.layers - 1:
                    x = r.forward_dense(x)
                else:
                    x = r(x, attn_mask=attn_mask)
        return x
class VisionTransformer(nn.Module):
    """CLIP-style ViT image tower: conv patch embed + CLS token +
    positional embedding -> transformer -> pooled (or dense) features
    projected to ``output_dim``."""

    def __init__(
            self,
            image_size: int,
            patch_size: int,
            width: int,
            layers: int,
            heads: int,
            mlp_ratio: float,
            ls_init_value: Optional[float] = None,
            global_average_pool: bool = False,
            output_dim: int = 512,
            patch_dropout: float = 0.,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        # Number of patches per spatial dimension.
        self.grid_size = (self.image_size[0] // self.patch_size[0], self.image_size[1] // self.patch_size[1])
        self.output_dim = output_dim
        # Non-overlapping patch embedding (stride == kernel == patch size).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # +1 position for the CLS token.
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))

        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity()

        self.ln_pre = norm_layer(width)
        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )

        self.global_average_pool = global_average_pool
        self.ln_post = norm_layer(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

        self.init_parameters()

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters; optionally re-enable grads for the last
        ``unlocked_groups`` parameter groups (stem, each block, head, proj).

        NOTE(review): ``freeze_bn_stats`` is accepted but unused here.
        """
        for param in self.parameters():
            param.requires_grad = False

        if unlocked_groups != 0:
            # Groups ordered stem -> blocks -> (last block + ln_post) -> proj.
            groups = [
                [
                    self.conv1,
                    self.class_embedding,
                    self.positional_embedding,
                    self.ln_pre,
                ],
                *self.transformer.resblocks[:-1],
                [
                    self.transformer.resblocks[-1],
                    self.ln_post,
                ],
                self.proj,
            ]

            def _unlock(x):
                # Recursively enable grads for lists, Parameters, and Modules.
                if isinstance(x, Sequence):
                    for g in x:
                        _unlock(g)
                else:
                    if isinstance(x, torch.nn.Parameter):
                        x.requires_grad = True
                    else:
                        for p in x.parameters():
                            p.requires_grad = True

            _unlock(groups[-unlocked_groups:])

    def init_parameters(self):
        # FIXME OpenAI CLIP did not define an init for the VisualTransformer
        # TODO experiment if default PyTorch init, below, or alternate init is best.

        # nn.init.normal_(self.class_embedding, std=self.scale)
        # nn.init.normal_(self.positional_embedding, std=self.scale)
        #
        # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        # attn_std = self.transformer.width ** -0.5
        # fc_std = (2 * self.transformer.width) ** -0.5
        # for block in self.transformer.resblocks:
        #     nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
        #     nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
        #     nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
        #     nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        #
        # if self.text_projection is not None:
        #     nn.init.normal_(self.text_projection, std=self.scale)
        pass

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing in the transformer."""
        self.transformer.grad_checkpointing = enable

    def forward(self, x: torch.Tensor, dense=False):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the CLS token (broadcast to batch size).
        x = torch.cat(
            [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
             x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)

        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, dense=dense)
        x = x.permute(1, 0, 2)  # LND -> NLD

        if self.global_average_pool:
            x = x.mean(dim=1)
        elif dense:
            # Keep per-token features for dense prediction.
            x = x
        else:
            # CLS token pooling.
            x = x[:, 0]

        x = self.ln_post(x)

        if self.proj is not None:
            x = x @ self.proj

        return x
class TextTransformer(nn.Module):
    """CLIP-style text tower: token + positional embedding -> causal
    transformer -> EOT-token feature projected to ``output_dim``."""

    def __init__(
            self,
            context_length: int = 77,
            vocab_size: int = 49408,
            width: int = 512,
            heads: int = 8,
            layers: int = 12,
            ls_init_value: Optional[float] = None,
            output_dim: int = 512,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
    ):
        super().__init__()
        self.context_length = context_length
        self.vocab_size = vocab_size
        self.width = width
        self.output_dim = output_dim

        self.token_embedding = nn.Embedding(vocab_size, width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, width))
        self.transformer = Transformer(
            width=width,
            layers=layers,
            heads=heads,
            ls_init_value=ls_init_value,
            act_layer=act_layer,
            norm_layer=norm_layer,
        )
        self.ln_final = norm_layer(width)
        self.text_projection = nn.Parameter(torch.empty(width, output_dim))

        # Causal mask; non-persistent so it is rebuilt rather than checkpointed.
        self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)

        self.init_parameters()

    def init_parameters(self):
        """OpenAI-CLIP-style normal init for embeddings and block weights."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable activation checkpointing in the transformer."""
        self.transformer.grad_checkpointing = enable

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, text):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)
        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return x
open_clip/src/open_clip/utils.py
0 → 100644
View file @
f55a786e
from
itertools
import
repeat
import
collections.abc
from
torch
import
nn
as
nn
from
torchvision.ops.misc
import
FrozenBatchNorm2d
def freeze_batch_norm_2d(module, module_match={}, name=''):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # NOTE(review): module_match uses a mutable default, but it is only read
    # (membership test), never mutated — safe as written.
    res = module
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            # Copy learned affine params detached from the autograd graph.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        # Recurse into children; replace any that were converted.
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
# From PyTorch internals
def
_ntuple
(
n
):
def
parse
(
x
):
if
isinstance
(
x
,
collections
.
abc
.
Iterable
):
return
x
return
tuple
(
repeat
(
x
,
n
))
return
parse
# Convenience converters: scalar -> fixed-size tuple (iterables pass through).
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
# Variable-size variant: to_ntuple(n, x) == _ntuple(n)(x).
to_ntuple = lambda n, x: _ntuple(n)(x)
open_clip/src/open_clip/version.py
0 → 100644
View file @
f55a786e
# Package version string for open_clip.
__version__ = '2.10.1'
open_clip/src/training/.gitignore
0 → 100644
View file @
f55a786e
logs/
open_clip/src/training/__init__.py
0 → 100644
View file @
f55a786e
open_clip/src/training/data.py
0 → 100644
View file @
f55a786e
import
ast
import
json
import
logging
import
math
import
os
import
random
import
sys
import
time
from
dataclasses
import
dataclass
from
multiprocessing
import
Value
import
numpy
as
np
import
pandas
as
pd
import
torch
import
torchvision.datasets
as
datasets
import
webdataset
as
wds
from
PIL
import
Image
from
torch.utils.data
import
Dataset
,
DataLoader
,
SubsetRandomSampler
,
IterableDataset
,
get_worker_info
from
torch.utils.data.distributed
import
DistributedSampler
from
webdataset.filters
import
_shuffle
from
webdataset.tariterators
import
base_plus_ext
,
url_opener
,
tar_file_expander
,
valid_sample
try
:
import
horovod.torch
as
hvd
except
ImportError
:
hvd
=
None
class CsvDataset(Dataset):
    """Image-text dataset backed by a CSV/TSV file.

    Each row provides an image path (``img_key`` column) and a caption
    (``caption_key`` column); ``transforms`` preprocess the image and
    ``tokenizer`` tokenizes the caption.
    """

    def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t", tokenizer=None):
        logging.debug(f'Loading csv data from {input_filename}.')
        df = pd.read_csv(input_filename, sep=sep)

        self.images = df[img_key].tolist()
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        logging.debug('Done loading data.')

        # NOTE(review): tokenizer defaults to None but __getitem__ calls it
        # unconditionally — callers must pass a tokenizer.
        self.tokenize = tokenizer

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        images = self.transforms(Image.open(str(self.images[idx])))
        # Tokenize a single caption; [0] unwraps the batch dimension.
        texts = self.tokenize([str(self.captions[idx])])[0]
        return images, texts
class SharedEpoch:
    """Epoch counter backed by a multiprocessing shared integer.

    Lets the training loop publish the current epoch to dataloader
    worker processes.
    """

    def __init__(self, epoch: int = 0):
        self.shared_epoch = Value('i', epoch)

    def set_value(self, epoch):
        """Publish a new epoch number."""
        self.shared_epoch.value = epoch

    def get_value(self):
        """Read the current epoch number."""
        return self.shared_epoch.value
@dataclass
class DataInfo:
    """Bundle of a dataloader plus the objects needed to sync epochs."""

    # The actual loader to iterate.
    dataloader: DataLoader
    # Optional sampler; only DistributedSampler gets set_epoch() calls.
    sampler: DistributedSampler = None
    # Optional shared epoch counter for webdataset worker processes.
    shared_epoch: SharedEpoch = None

    def set_epoch(self, epoch):
        """Propagate the epoch to the shared counter and/or sampler."""
        if self.shared_epoch is not None:
            self.shared_epoch.set_value(epoch)
        if self.sampler is not None and isinstance(self.sampler, DistributedSampler):
            self.sampler.set_epoch(epoch)
def get_dataset_size(shards):
    """Infer the sample count for a webdataset shard spec.

    Looks next to the first shard for either a ``sizes.json`` mapping
    (shard basename -> sample count) or a ``__len__`` file containing a
    single literal.

    Args:
        shards: shard URL list or brace-notation string.

    Returns:
        (total_size, num_shards) — total_size is None when neither
        metadata file exists.
    """
    shards_list = wds.shardlists.expand_urls(shards)
    dir_path = os.path.dirname(shards_list[0])
    sizes_filename = os.path.join(dir_path, 'sizes.json')
    len_filename = os.path.join(dir_path, '__len__')
    if os.path.exists(sizes_filename):
        # Use a context manager so the file handle is closed promptly
        # (was previously json.load(open(...)), which leaked the handle).
        with open(sizes_filename, 'r') as f:
            sizes = json.load(f)
        total_size = sum([int(sizes[os.path.basename(shard)]) for shard in shards_list])
    elif os.path.exists(len_filename):
        # FIXME this used to be eval(open(...)) but that seemed rather unsafe
        with open(len_filename, 'r') as f:
            total_size = ast.literal_eval(f.read())
    else:
        total_size = None  # num samples undefined
        # some common dataset sizes (at time of authors last download)
        # CC3M (train): 2905954
        # CC12M: 10968539
        # LAION-400M: 407332084
        # LAION-2B (english): 2170337258
    num_shards = len(shards_list)
    return total_size, num_shards
def get_imagenet(args, preprocess_fns, split):
    """Build a DataLoader (wrapped in DataInfo) for an ImageNet split.

    split is one of "train", "val", or "v2" (ImageNet-V2 test set).
    For "train", a fixed-size random subset of k=50 images per class is
    sampled via SubsetRandomSampler.
    """
    assert split in ["train", "val", "v2"]
    is_train = split == "train"
    preprocess_train, preprocess_val = preprocess_fns

    if split == "v2":
        from imagenetv2_pytorch import ImageNetV2Dataset
        dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
    else:
        if is_train:
            data_path = args.imagenet_train
            preprocess_fn = preprocess_train
        else:
            data_path = args.imagenet_val
            preprocess_fn = preprocess_val
        assert data_path
        dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)

    if is_train:
        # Build a 0/1 keep-mask over the dataset: 50 random images per class.
        idxs = np.zeros(len(dataset.targets))
        target_array = np.array(dataset.targets)
        k = 50
        for c in range(1000):
            m = target_array == c
            n = len(idxs[m])
            arr = np.zeros(n)
            arr[:k] = 1
            np.random.shuffle(arr)
            idxs[m] = arr

        idxs = idxs.astype('int')
        sampler = SubsetRandomSampler(np.where(idxs)[0])
    else:
        sampler = None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=args.workers,
        sampler=sampler,
    )

    return DataInfo(dataloader=dataloader, sampler=sampler)
def count_samples(dataloader):
    """Exhaust a (images, texts) dataloader and count its contents.

    Returns:
        (n_elements, n_batches): total samples seen and batches yielded.
    """
    os.environ["WDS_EPOCH"] = "0"
    total_samples = 0
    total_batches = 0
    for images, texts in dataloader:
        total_batches += 1
        total_samples += len(images)
        assert len(images) == len(texts)
    return total_samples, total_batches
def filter_no_caption_or_no_image(sample):
    """Keep only samples that have both a 'txt' caption and an image key."""
    if 'txt' not in sample:
        return False
    return 'png' in sample or 'jpg' in sample or 'jpeg' in sample or 'webp' in sample
def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    message = f'Handling webdataset error ({repr(exn)}). Ignoring.'
    logging.warning(message)
    # True tells the webdataset pipeline to keep going.
    return True
def group_by_keys_nothrow(data, keys=base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.

    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    # Accumulates one sample (dict of suffix -> data) at a time.
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if prefix is None:
            # Filename did not split into (base, extension); skip it.
            continue
        if lcase:
            suffix = suffix.lower()
        # FIXME webdataset version throws if suffix in current_sample, but we have a potential for
        # this happening in the current LAION400m dataset if a tar ends with same prefix as the next
        # begins, rare, but can happen since prefix aren't unique across tar files in that dataset
        if current_sample is None or prefix != current_sample["__key__"] or suffix in current_sample:
            # New sample boundary: flush the previous one if it is valid.
            if valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    # Flush the trailing sample.
    if valid_sample(current_sample):
        yield current_sample
def tarfile_to_samples_nothrow(src, handler=log_and_continue):
    """Expand tar shards into grouped samples without raising on duplicates.

    Pipeline: open shard URLs -> expand tar members -> group members
    into samples by key prefix.
    """
    # NOTE this is a re-impl of the webdataset impl with group_by_keys that doesn't throw
    streams = url_opener(src, handler=handler)
    files = tar_file_expander(streams, handler=handler)
    samples = group_by_keys_nothrow(files, handler=handler)
    return samples
def pytorch_worker_seed(increment=0):
    """get dataloader worker seed from pytorch"""
    info = get_worker_info()
    if info is None:
        # fallback to wds rank based seed
        return wds.utils.pytorch_worker_seed()
    # favour using the seed already created for pytorch dataloader workers if it exists
    base_seed = info.seed
    if increment:
        # space out seed increments so they can't overlap across workers in different iterations
        base_seed += increment * max(1, info.num_workers)
    return base_seed
# Shuffle-buffer sizes for the webdataset pipeline: shard-level buffers are
# small (few shards in flight), sample-level buffers are larger. The
# *_INITIAL values are how full a buffer must be before yielding begins.
_SHARD_SHUFFLE_SIZE = 2000
_SHARD_SHUFFLE_INITIAL = 500
_SAMPLE_SHUFFLE_SIZE = 5000
_SAMPLE_SHUFFLE_INITIAL = 1000
class detshuffle2(wds.PipelineStage):
    """Deterministic shuffle stage: the RNG is re-seeded each epoch from
    (seed + epoch), so every node/worker shuffles shards identically.
    With ``seed < 0`` the per-worker pytorch seed is used instead."""

    def __init__(
            self,
            bufsize=1000,
            initial=100,
            seed=0,
            epoch=-1,
    ):
        self.bufsize = bufsize
        self.initial = initial
        self.seed = seed
        # Either an int (locally incremented) or a SharedEpoch counter.
        self.epoch = epoch

    def run(self, src):
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        rng = random.Random()
        if self.seed < 0:
            # If seed is negative, we use the worker's seed, this will be different across all nodes/workers
            seed = pytorch_worker_seed(epoch)
        else:
            # This seed to be deterministic AND the same across all nodes/workers in each epoch
            seed = self.seed + epoch
        rng.seed(seed)
        return _shuffle(src, self.bufsize, self.initial, rng)
class ResampledShards2(IterableDataset):
    """An iterable dataset yielding a list of urls."""

    def __init__(
        self,
        urls,
        nshards=sys.maxsize,
        worker_seed=None,
        deterministic=False,
        epoch=-1,
    ):
        """Sample shards from the shard list with replacement.

        :param urls: a list of URLs as a Python list or brace notation string
        """
        super().__init__()
        urls = wds.shardlists.expand_urls(urls)
        self.urls = urls
        assert isinstance(self.urls[0], str)
        self.nshards = nshards
        self.rng = random.Random()
        # Optional callable returning a base seed for deterministic sampling.
        self.worker_seed = worker_seed
        self.deterministic = deterministic
        # Either an int (locally incremented) or a SharedEpoch counter.
        self.epoch = epoch

    def __iter__(self):
        """Return an iterator over the shards."""
        if isinstance(self.epoch, SharedEpoch):
            epoch = self.epoch.get_value()
        else:
            # NOTE: this is epoch tracking is problematic in a multiprocess (dataloader workers or train)
            # situation as different workers may wrap at different times (or not at all).
            self.epoch += 1
            epoch = self.epoch
        if self.deterministic:
            # reset seed w/ epoch if deterministic
            if self.worker_seed is None:
                # pytorch worker seed should be deterministic due to being init by arg.seed + rank + worker id
                seed = pytorch_worker_seed(epoch)
            else:
                seed = self.worker_seed() + epoch
            self.rng.seed(seed)
        # Yield nshards URL dicts, sampled with replacement.
        for _ in range(self.nshards):
            yield dict(url=self.rng.choice(self.urls))
def get_wds_dataset(args, preprocess_img, is_train, epoch=0, floor=False, tokenizer=None):
    """Build a webdataset-backed DataInfo for the train or val split.

    Args:
        args: parsed training args; reads train_data/val_data, batch_size, workers,
            world_size, seed, train_num_samples/val_num_samples, dataset_resampled.
        preprocess_img: image transform applied to every decoded sample.
        is_train: True builds the shuffled training pipeline, False the eval one.
        epoch: starting epoch, held in a SharedEpoch so dataloader workers can
            derive deterministic per-epoch shuffles.
        floor: round batch counts down (math.floor) instead of up (math.ceil).
        tokenizer: text tokenizer; tokenizer(text)[0] is applied to each caption.

    Returns:
        DataInfo wrapping a wds.WebLoader plus the SharedEpoch handle.
    """
    input_shards = args.train_data if is_train else args.val_data
    assert input_shards is not None
    # shard resampling (with replacement) only makes sense for training
    resampled = getattr(args, 'dataset_resampled', False) and is_train

    num_samples, num_shards = get_dataset_size(input_shards)
    if not num_samples:
        if is_train:
            num_samples = args.train_num_samples
            if not num_samples:
                raise RuntimeError(
                    'Currently, number of dataset samples must be specified for training dataset. '
                    'Please specify via `--train-num-samples` if no dataset length info present.')
        else:
            num_samples = args.val_num_samples or 0  # eval will just exhaust the iterator if not specified

    shared_epoch = SharedEpoch(epoch=epoch)  # create a shared epoch store to sync epoch to dataloader worker proc
    if resampled:
        pipeline = [ResampledShards2(input_shards, deterministic=True, epoch=shared_epoch)]
    else:
        pipeline = [wds.SimpleShardList(input_shards)]

    # at this point we have an iterator over all the shards
    if is_train:
        if not resampled:
            pipeline.extend([
                detshuffle2(
                    bufsize=_SHARD_SHUFFLE_SIZE,
                    initial=_SHARD_SHUFFLE_INITIAL,
                    seed=args.seed,
                    epoch=shared_epoch,
                ),
                wds.split_by_node,
                wds.split_by_worker,
            ])
        pipeline.extend([
            # at this point, we have an iterator over the shards assigned to each worker at each node
            tarfile_to_samples_nothrow,  # wds.tarfile_to_samples(handler=log_and_continue),
            wds.shuffle(
                bufsize=_SAMPLE_SHUFFLE_SIZE,
                initial=_SAMPLE_SHUFFLE_INITIAL,
            ),
        ])
    else:
        pipeline.extend([
            wds.split_by_worker,
            # at this point, we have an iterator over the shards assigned to each worker
            wds.tarfile_to_samples(handler=log_and_continue),
        ])
    # common tail of the pipeline: filter, decode, preprocess, and batch
    pipeline.extend([
        wds.select(filter_no_caption_or_no_image),
        wds.decode("pilrgb", handler=log_and_continue),
        wds.rename(image="jpg;png;jpeg;webp", text="txt"),
        wds.map_dict(image=preprocess_img, text=lambda text: tokenizer(text)[0]),
        wds.to_tuple("image", "text"),
        wds.batched(args.batch_size, partial=not is_train),
    ])

    dataset = wds.DataPipeline(*pipeline)

    if is_train:
        if not resampled:
            assert num_shards >= args.workers * args.world_size, 'number of shards must be >= total workers'
        # roll over and repeat a few samples to get same number of full batches on each node
        round_fn = math.floor if floor else math.ceil
        global_batch_size = args.batch_size * args.world_size
        num_batches = round_fn(num_samples / global_batch_size)
        num_workers = max(1, args.workers)
        num_worker_batches = round_fn(num_batches / num_workers)  # per dataloader worker
        num_batches = num_worker_batches * num_workers
        num_samples = num_batches * global_batch_size
        dataset = dataset.with_epoch(num_worker_batches)  # each worker is iterating over this
    else:
        # last batches are partial, eval is done on single (master) node
        num_batches = math.ceil(num_samples / args.batch_size)

    dataloader = wds.WebLoader(
        dataset,
        batch_size=None,
        shuffle=False,
        num_workers=args.workers,
        persistent_workers=True,
    )

    # FIXME not clear which approach is better, with_epoch before vs after dataloader?
    # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
    # if is_train:
    #     # roll over and repeat a few samples to get same number of full batches on each node
    #     global_batch_size = args.batch_size * args.world_size
    #     num_batches = math.ceil(num_samples / global_batch_size)
    #     num_workers = max(1, args.workers)
    #     num_batches = math.ceil(num_batches / num_workers) * num_workers
    #     num_samples = num_batches * global_batch_size
    #     dataloader = dataloader.with_epoch(num_batches)
    # else:
    #     # last batches are partial, eval is done on single (master) node
    #     num_batches = math.ceil(num_samples / args.batch_size)

    # add meta-data to dataloader instance for convenience
    dataloader.num_batches = num_batches
    dataloader.num_samples = num_samples

    return DataInfo(dataloader=dataloader, shared_epoch=shared_epoch)
def get_csv_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build a DataInfo around a CSV/TSV image-caption dataset.

    Uses a DistributedSampler for distributed training; otherwise shuffles
    directly in the DataLoader when training. Attaches num_samples and
    num_batches to the dataloader for convenience.
    """
    csv_path = args.train_data if is_train else args.val_data
    assert csv_path
    dataset = CsvDataset(
        csv_path,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
        tokenizer=tokenizer,
    )
    sample_count = len(dataset)
    if args.distributed and is_train:
        sampler = DistributedSampler(dataset)
    else:
        sampler = None
    # DataLoader forbids shuffle together with a sampler
    do_shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=do_shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = sample_count
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
class SyntheticDataset(Dataset):
    """A fixed-size dataset of one solid-color image with a constant caption.

    Useful for benchmarking the training loop without real data. Every item is
    the same (optionally transformed) blank RGB image paired with the tokenized
    caption.
    """

    def __init__(self, transform=None, image_size=(224, 224), caption="Dummy caption", dataset_size=100, tokenizer=None):
        self.transform = transform        # optional image transform
        self.image_size = image_size      # (width, height) of the blank image
        self.caption = caption            # caption repeated for every sample
        self.image = Image.new('RGB', image_size)  # single shared blank image
        self.dataset_size = dataset_size  # reported length
        # tokenizer(text)[0]: first (only) sequence of the tokenized batch;
        # mirrors the webdataset/text pipeline convention used elsewhere.
        self.preprocess_txt = lambda text: tokenizer(text)[0]

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, idx):
        # Bug fix: previously `image` was unbound (UnboundLocalError) when
        # transform was None, despite transform=None being the default.
        image = self.image
        if self.transform is not None:
            image = self.transform(image)
        return image, self.preprocess_txt(self.caption)
def get_synthetic_dataset(args, preprocess_fn, is_train, epoch=0, tokenizer=None):
    """Build a DataInfo around an in-memory SyntheticDataset (for benchmarking)."""
    # image size is read off the first transform — assumed to carry a .size
    # (e.g. Resize/CenterCrop); TODO confirm against the preprocess pipeline
    image_size = preprocess_fn.transforms[0].size
    dataset = SyntheticDataset(
        transform=preprocess_fn,
        image_size=image_size,
        dataset_size=args.train_num_samples,
        tokenizer=tokenizer,
    )
    sample_count = len(dataset)
    if args.distributed and is_train:
        sampler = DistributedSampler(dataset)
    else:
        sampler = None
    # DataLoader forbids shuffle together with a sampler
    do_shuffle = is_train and sampler is None

    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=do_shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader.num_samples = sample_count
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
def get_dataset_fn(data_path, dataset_type):
    """Resolve the dataset factory for `dataset_type`.

    For "auto", the factory is inferred from the file extension of `data_path`.
    Raises ValueError for unknown types or unrecognized extensions.
    """
    if dataset_type == "webdataset":
        return get_wds_dataset
    if dataset_type == "csv":
        return get_csv_dataset
    if dataset_type == "synthetic":
        return get_synthetic_dataset
    if dataset_type == "auto":
        ext = data_path.split('.')[-1]
        if ext in ('csv', 'tsv'):
            return get_csv_dataset
        if ext == 'tar':
            return get_wds_dataset
        raise ValueError(
            f"Tried to figure out dataset type, but failed for extension {ext}.")
    raise ValueError(f"Unsupported dataset type: {dataset_type}")
def get_data(args, preprocess_fns, epoch=0, tokenizer=None):
    """Assemble all requested data splits into a dict of DataInfo objects.

    Keys produced (when configured): "train", "val", "imagenet-val",
    "imagenet-v2".
    """
    preprocess_train, preprocess_val = preprocess_fns
    data = {}

    # synthetic training data needs no train_data path
    want_train = bool(args.train_data) or args.dataset_type == "synthetic"
    if want_train:
        train_factory = get_dataset_fn(args.train_data, args.dataset_type)
        data["train"] = train_factory(
            args, preprocess_train, is_train=True, epoch=epoch, tokenizer=tokenizer)

    if args.val_data:
        val_factory = get_dataset_fn(args.val_data, args.dataset_type)
        data["val"] = val_factory(
            args, preprocess_val, is_train=False, tokenizer=tokenizer)

    if args.imagenet_val is not None:
        data["imagenet-val"] = get_imagenet(args, preprocess_fns, "val")

    if args.imagenet_v2 is not None:
        data["imagenet-v2"] = get_imagenet(args, preprocess_fns, "v2")

    return data
open_clip/src/training/distributed.py
0 → 100644
View file @
f55a786e
import
os
import
torch
import
torch.distributed
as
dist
try
:
import
horovod.torch
as
hvd
except
ImportError
:
hvd
=
None
def is_global_master(args):
    """Return True iff this process holds global rank 0."""
    rank = args.rank
    return rank == 0
def is_local_master(args):
    """Return True iff this process holds rank 0 on its own node."""
    rank = args.local_rank
    return rank == 0
def is_master(args, local=False):
    """Return True for the master process: per-node when local=True, global otherwise."""
    if local:
        return is_local_master(args)
    return is_global_master(args)
def is_using_horovod():
    """Heuristically detect a horovod/MPI launch from environment variables."""
    # NOTE w/ horovod run, OMPI vars should be set, but w/ SLURM PMI vars will be set
    # Differentiating between horovod and DDP use via SLURM may not be possible, so horovod arg still required...
    def _all_present(names):
        return all(name in os.environ for name in names)

    ompi = ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE")
    pmi = ("PMI_RANK", "PMI_SIZE")
    return _all_present(ompi) or _all_present(pmi)
def is_using_distributed():
    """Return True when env vars indicate a multi-process (world size > 1) launch.

    WORLD_SIZE (torchrun) takes precedence over SLURM_NTASKS.
    """
    for var in ('WORLD_SIZE', 'SLURM_NTASKS'):
        if var in os.environ:
            return int(os.environ[var]) > 1
    return False
def world_info_from_env():
    """Read (local_rank, global_rank, world_size) from launcher environment vars.

    Each value is taken from the first variable present among the candidates
    for that value (covering torchrun, MPI, SLURM, and OpenMPI launchers);
    defaults are (0, 0, 1) for a single-process run.
    """
    def _first_int(candidates, default):
        for name in candidates:
            if name in os.environ:
                return int(os.environ[name])
        return default

    local_rank = _first_int(
        ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'), 0)
    global_rank = _first_int(
        ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'), 0)
    world_size = _first_int(
        ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'), 1)
    return local_rank, global_rank, world_size
def init_distributed_device(args):
    """Initialize the distributed backend (horovod, SLURM DDP, or torchrun DDP)
    and select this process's device.

    Mutates args in place: sets args.distributed, args.world_size, args.rank,
    args.local_rank, and args.device. Also mirrors rank/world-size into the
    LOCAL_RANK/RANK/WORLD_SIZE env vars for the horovod and SLURM paths.

    Returns:
        torch.device for this process ('cuda:N' or 'cpu').
    """
    # Distributed training = training on more than one GPU.
    # Works in both single and multi-node scenarios.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0
    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        args.local_rank = int(hvd.local_rank())
        args.rank = hvd.rank()
        args.world_size = hvd.size()
        args.distributed = True
        os.environ['LOCAL_RANK'] = str(args.local_rank)
        os.environ['RANK'] = str(args.rank)
        os.environ['WORLD_SIZE'] = str(args.world_size)
    elif is_using_distributed():
        if 'SLURM_PROCID' in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            os.environ['LOCAL_RANK'] = str(args.local_rank)
            os.environ['RANK'] = str(args.rank)
            os.environ['WORLD_SIZE'] = str(args.world_size)
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url,
                world_size=args.world_size,
                rank=args.rank,
            )
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            # world size / rank come from the process group itself here
            torch.distributed.init_process_group(
                backend=args.dist_backend,
                init_method=args.dist_url)
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True

    if torch.cuda.is_available():
        if args.distributed and not args.no_set_device_rank:
            # pin each process to its node-local GPU
            device = 'cuda:%d' % args.local_rank
        else:
            device = 'cuda:0'
        torch.cuda.set_device(device)
    else:
        device = 'cpu'
    args.device = device  # keep the string form on args
    device = torch.device(device)
    return device
def broadcast_object(args, obj, src=0):
    """Broadcast a pickle-able python object from rank `src` to all ranks."""
    if args.horovod:
        return hvd.broadcast_object(obj, root_rank=src)
    # torch.distributed path: broadcast_object_list mutates the list in place
    payload = [obj] if args.rank == src else [None]
    dist.broadcast_object_list(payload, src=src)
    return payload[0]
def all_gather_object(args, obj, dst=0):
    """Gather a pickle-able python object from every rank; returns the full list."""
    if args.horovod:
        return hvd.allgather_object(obj)
    gathered = [None] * args.world_size
    dist.all_gather_object(gathered, obj)
    return gathered
open_clip/src/training/file_utils.py
0 → 100644
View file @
f55a786e
import
logging
import
os
import
multiprocessing
import
subprocess
import
time
import
fsspec
import
torch
from
tqdm
import
tqdm
def remote_sync_s3(local_dir, remote_dir):
    """Mirror local_dir to remote_dir with the aws CLI; returns True on success."""
    # skip epoch_latest which can change during sync.
    cmd = ["aws", "s3", "sync", local_dir, remote_dir, '--exclude', '*epoch_latest.pt']
    result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode == 0:
        logging.info("Successfully synced with S3 bucket")
        return True
    logging.error(f"Error: Failed to sync with S3 bucket {result.stderr.decode('utf-8')}")
    return False
def remote_sync_fsspec(local_dir, remote_dir):
    """Copy every key from local_dir to remote_dir via fsspec mappers.

    Skips `epoch_latest.pt` (can change mid-sync) and keys whose remote copy
    already has the same byte length. Returns False on the first failed copy,
    True otherwise.
    """
    # FIXME currently this is slow and not recommended. Look into speeding up.
    a = fsspec.get_mapper(local_dir)
    b = fsspec.get_mapper(remote_dir)

    for k in a:
        # skip epoch_latest which can change during sync.
        if 'epoch_latest.pt' in k:
            continue
        logging.info(f'Attempting to sync {k}')
        # same length remotely -> assume already synced (cheap heuristic)
        if k in b and len(a[k]) == len(b[k]):
            logging.debug(f'Skipping remote sync for {k}.')
            continue
        try:
            b[k] = a[k]
            # Bug fix: the success message was previously logged *before* the
            # copy was attempted, so failures were reported as successes too.
            logging.info(f'Successful sync for {k}.')
        except Exception as e:
            logging.info(f'Error during remote sync for {k}: {e}')
            return False
    return True
def remote_sync(local_dir, remote_dir, protocol):
    """Run one sync of local_dir to remote_dir using the given protocol.

    Supported protocols: 's3' (aws CLI) and 'fsspec'. Returns False for
    unknown protocols.
    """
    logging.info('Starting remote sync.')
    if protocol == 's3':
        return remote_sync_s3(local_dir, remote_dir)
    if protocol == 'fsspec':
        return remote_sync_fsspec(local_dir, remote_dir)
    logging.error('Remote protocol not known')
    return False
def keep_running_remote_sync(sync_every, local_dir, remote_dir, protocol):
    # Loop forever, syncing local_dir to remote_dir every `sync_every` seconds.
    # Intended to run inside a background process (see start_sync_process);
    # it never returns and dies with the parent process.
    while True:
        time.sleep(sync_every)
        remote_sync(local_dir, remote_dir, protocol)
def start_sync_process(sync_every, local_dir, remote_dir, protocol):
    """Create (but do not start) a process that periodically runs remote_sync."""
    worker_args = (sync_every, local_dir, remote_dir, protocol)
    process = multiprocessing.Process(
        target=keep_running_remote_sync,
        args=worker_args,
    )
    return process
# Note: we are not currently using this save function.
def pt_save(pt_obj, file_path):
    """Serialize a torch object to file_path through fsspec (supports remote paths)."""
    of = fsspec.open(file_path, "wb")
    with of as f:
        # Bug fix: previously saved to `file_path` directly, bypassing the
        # fsspec handle `f` and breaking non-local (remote) destinations.
        torch.save(pt_obj, f)
def pt_load(file_path, map_location=None):
    """Load a torch checkpoint from a local or fsspec-addressable path."""
    # heuristic: non-absolute paths are treated as remote (may be slow)
    if not file_path.startswith('/'):
        logging.info('Loading remote checkpoint, which may take a bit.')
    with fsspec.open(file_path, "rb") as f:
        checkpoint = torch.load(f, map_location=map_location)
    return checkpoint
def check_exists(file_path):
    """Return True when file_path can be opened via fsspec, False when missing."""
    try:
        with fsspec.open(file_path):
            pass
    except FileNotFoundError:
        return False
    else:
        return True
open_clip/src/training/imagenet_zeroshot_data.py
0 → 100644
View file @
f55a786e
imagenet_classnames
=
[
"tench"
,
"goldfish"
,
"great white shark"
,
"tiger shark"
,
"hammerhead shark"
,
"electric ray"
,
"stingray"
,
"rooster"
,
"hen"
,
"ostrich"
,
"brambling"
,
"goldfinch"
,
"house finch"
,
"junco"
,
"indigo bunting"
,
"American robin"
,
"bulbul"
,
"jay"
,
"magpie"
,
"chickadee"
,
"American dipper"
,
"kite (bird of prey)"
,
"bald eagle"
,
"vulture"
,
"great grey owl"
,
"fire salamander"
,
"smooth newt"
,
"newt"
,
"spotted salamander"
,
"axolotl"
,
"American bullfrog"
,
"tree frog"
,
"tailed frog"
,
"loggerhead sea turtle"
,
"leatherback sea turtle"
,
"mud turtle"
,
"terrapin"
,
"box turtle"
,
"banded gecko"
,
"green iguana"
,
"Carolina anole"
,
"desert grassland whiptail lizard"
,
"agama"
,
"frilled-necked lizard"
,
"alligator lizard"
,
"Gila monster"
,
"European green lizard"
,
"chameleon"
,
"Komodo dragon"
,
"Nile crocodile"
,
"American alligator"
,
"triceratops"
,
"worm snake"
,
"ring-necked snake"
,
"eastern hog-nosed snake"
,
"smooth green snake"
,
"kingsnake"
,
"garter snake"
,
"water snake"
,
"vine snake"
,
"night snake"
,
"boa constrictor"
,
"African rock python"
,
"Indian cobra"
,
"green mamba"
,
"sea snake"
,
"Saharan horned viper"
,
"eastern diamondback rattlesnake"
,
"sidewinder rattlesnake"
,
"trilobite"
,
"harvestman"
,
"scorpion"
,
"yellow garden spider"
,
"barn spider"
,
"European garden spider"
,
"southern black widow"
,
"tarantula"
,
"wolf spider"
,
"tick"
,
"centipede"
,
"black grouse"
,
"ptarmigan"
,
"ruffed grouse"
,
"prairie grouse"
,
"peafowl"
,
"quail"
,
"partridge"
,
"african grey parrot"
,
"macaw"
,
"sulphur-crested cockatoo"
,
"lorikeet"
,
"coucal"
,
"bee eater"
,
"hornbill"
,
"hummingbird"
,
"jacamar"
,
"toucan"
,
"duck"
,
"red-breasted merganser"
,
"goose"
,
"black swan"
,
"tusker"
,
"echidna"
,
"platypus"
,
"wallaby"
,
"koala"
,
"wombat"
,
"jellyfish"
,
"sea anemone"
,
"brain coral"
,
"flatworm"
,
"nematode"
,
"conch"
,
"snail"
,
"slug"
,
"sea slug"
,
"chiton"
,
"chambered nautilus"
,
"Dungeness crab"
,
"rock crab"
,
"fiddler crab"
,
"red king crab"
,
"American lobster"
,
"spiny lobster"
,
"crayfish"
,
"hermit crab"
,
"isopod"
,
"white stork"
,
"black stork"
,
"spoonbill"
,
"flamingo"
,
"little blue heron"
,
"great egret"
,
"bittern bird"
,
"crane bird"
,
"limpkin"
,
"common gallinule"
,
"American coot"
,
"bustard"
,
"ruddy turnstone"
,
"dunlin"
,
"common redshank"
,
"dowitcher"
,
"oystercatcher"
,
"pelican"
,
"king penguin"
,
"albatross"
,
"grey whale"
,
"killer whale"
,
"dugong"
,
"sea lion"
,
"Chihuahua"
,
"Japanese Chin"
,
"Maltese"
,
"Pekingese"
,
"Shih Tzu"
,
"King Charles Spaniel"
,
"Papillon"
,
"toy terrier"
,
"Rhodesian Ridgeback"
,
"Afghan Hound"
,
"Basset Hound"
,
"Beagle"
,
"Bloodhound"
,
"Bluetick Coonhound"
,
"Black and Tan Coonhound"
,
"Treeing Walker Coonhound"
,
"English foxhound"
,
"Redbone Coonhound"
,
"borzoi"
,
"Irish Wolfhound"
,
"Italian Greyhound"
,
"Whippet"
,
"Ibizan Hound"
,
"Norwegian Elkhound"
,
"Otterhound"
,
"Saluki"
,
"Scottish Deerhound"
,
"Weimaraner"
,
"Staffordshire Bull Terrier"
,
"American Staffordshire Terrier"
,
"Bedlington Terrier"
,
"Border Terrier"
,
"Kerry Blue Terrier"
,
"Irish Terrier"
,
"Norfolk Terrier"
,
"Norwich Terrier"
,
"Yorkshire Terrier"
,
"Wire Fox Terrier"
,
"Lakeland Terrier"
,
"Sealyham Terrier"
,
"Airedale Terrier"
,
"Cairn Terrier"
,
"Australian Terrier"
,
"Dandie Dinmont Terrier"
,
"Boston Terrier"
,
"Miniature Schnauzer"
,
"Giant Schnauzer"
,
"Standard Schnauzer"
,
"Scottish Terrier"
,
"Tibetan Terrier"
,
"Australian Silky Terrier"
,
"Soft-coated Wheaten Terrier"
,
"West Highland White Terrier"
,
"Lhasa Apso"
,
"Flat-Coated Retriever"
,
"Curly-coated Retriever"
,
"Golden Retriever"
,
"Labrador Retriever"
,
"Chesapeake Bay Retriever"
,
"German Shorthaired Pointer"
,
"Vizsla"
,
"English Setter"
,
"Irish Setter"
,
"Gordon Setter"
,
"Brittany dog"
,
"Clumber Spaniel"
,
"English Springer Spaniel"
,
"Welsh Springer Spaniel"
,
"Cocker Spaniel"
,
"Sussex Spaniel"
,
"Irish Water Spaniel"
,
"Kuvasz"
,
"Schipperke"
,
"Groenendael dog"
,
"Malinois"
,
"Briard"
,
"Australian Kelpie"
,
"Komondor"
,
"Old English Sheepdog"
,
"Shetland Sheepdog"
,
"collie"
,
"Border Collie"
,
"Bouvier des Flandres dog"
,
"Rottweiler"
,
"German Shepherd Dog"
,
"Dobermann"
,
"Miniature Pinscher"
,
"Greater Swiss Mountain Dog"
,
"Bernese Mountain Dog"
,
"Appenzeller Sennenhund"
,
"Entlebucher Sennenhund"
,
"Boxer"
,
"Bullmastiff"
,
"Tibetan Mastiff"
,
"French Bulldog"
,
"Great Dane"
,
"St. Bernard"
,
"husky"
,
"Alaskan Malamute"
,
"Siberian Husky"
,
"Dalmatian"
,
"Affenpinscher"
,
"Basenji"
,
"pug"
,
"Leonberger"
,
"Newfoundland dog"
,
"Great Pyrenees dog"
,
"Samoyed"
,
"Pomeranian"
,
"Chow Chow"
,
"Keeshond"
,
"brussels griffon"
,
"Pembroke Welsh Corgi"
,
"Cardigan Welsh Corgi"
,
"Toy Poodle"
,
"Miniature Poodle"
,
"Standard Poodle"
,
"Mexican hairless dog (xoloitzcuintli)"
,
"grey wolf"
,
"Alaskan tundra wolf"
,
"red wolf or maned wolf"
,
"coyote"
,
"dingo"
,
"dhole"
,
"African wild dog"
,
"hyena"
,
"red fox"
,
"kit fox"
,
"Arctic fox"
,
"grey fox"
,
"tabby cat"
,
"tiger cat"
,
"Persian cat"
,
"Siamese cat"
,
"Egyptian Mau"
,
"cougar"
,
"lynx"
,
"leopard"
,
"snow leopard"
,
"jaguar"
,
"lion"
,
"tiger"
,
"cheetah"
,
"brown bear"
,
"American black bear"
,
"polar bear"
,
"sloth bear"
,
"mongoose"
,
"meerkat"
,
"tiger beetle"
,
"ladybug"
,
"ground beetle"
,
"longhorn beetle"
,
"leaf beetle"
,
"dung beetle"
,
"rhinoceros beetle"
,
"weevil"
,
"fly"
,
"bee"
,
"ant"
,
"grasshopper"
,
"cricket insect"
,
"stick insect"
,
"cockroach"
,
"praying mantis"
,
"cicada"
,
"leafhopper"
,
"lacewing"
,
"dragonfly"
,
"damselfly"
,
"red admiral butterfly"
,
"ringlet butterfly"
,
"monarch butterfly"
,
"small white butterfly"
,
"sulphur butterfly"
,
"gossamer-winged butterfly"
,
"starfish"
,
"sea urchin"
,
"sea cucumber"
,
"cottontail rabbit"
,
"hare"
,
"Angora rabbit"
,
"hamster"
,
"porcupine"
,
"fox squirrel"
,
"marmot"
,
"beaver"
,
"guinea pig"
,
"common sorrel horse"
,
"zebra"
,
"pig"
,
"wild boar"
,
"warthog"
,
"hippopotamus"
,
"ox"
,
"water buffalo"
,
"bison"
,
"ram (adult male sheep)"
,
"bighorn sheep"
,
"Alpine ibex"
,
"hartebeest"
,
"impala (antelope)"
,
"gazelle"
,
"arabian camel"
,
"llama"
,
"weasel"
,
"mink"
,
"European polecat"
,
"black-footed ferret"
,
"otter"
,
"skunk"
,
"badger"
,
"armadillo"
,
"three-toed sloth"
,
"orangutan"
,
"gorilla"
,
"chimpanzee"
,
"gibbon"
,
"siamang"
,
"guenon"
,
"patas monkey"
,
"baboon"
,
"macaque"
,
"langur"
,
"black-and-white colobus"
,
"proboscis monkey"
,
"marmoset"
,
"white-headed capuchin"
,
"howler monkey"
,
"titi monkey"
,
"Geoffroy's spider monkey"
,
"common squirrel monkey"
,
"ring-tailed lemur"
,
"indri"
,
"Asian elephant"
,
"African bush elephant"
,
"red panda"
,
"giant panda"
,
"snoek fish"
,
"eel"
,
"silver salmon"
,
"rock beauty fish"
,
"clownfish"
,
"sturgeon"
,
"gar fish"
,
"lionfish"
,
"pufferfish"
,
"abacus"
,
"abaya"
,
"academic gown"
,
"accordion"
,
"acoustic guitar"
,
"aircraft carrier"
,
"airliner"
,
"airship"
,
"altar"
,
"ambulance"
,
"amphibious vehicle"
,
"analog clock"
,
"apiary"
,
"apron"
,
"trash can"
,
"assault rifle"
,
"backpack"
,
"bakery"
,
"balance beam"
,
"balloon"
,
"ballpoint pen"
,
"Band-Aid"
,
"banjo"
,
"baluster / handrail"
,
"barbell"
,
"barber chair"
,
"barbershop"
,
"barn"
,
"barometer"
,
"barrel"
,
"wheelbarrow"
,
"baseball"
,
"basketball"
,
"bassinet"
,
"bassoon"
,
"swimming cap"
,
"bath towel"
,
"bathtub"
,
"station wagon"
,
"lighthouse"
,
"beaker"
,
"military hat (bearskin or shako)"
,
"beer bottle"
,
"beer glass"
,
"bell tower"
,
"baby bib"
,
"tandem bicycle"
,
"bikini"
,
"ring binder"
,
"binoculars"
,
"birdhouse"
,
"boathouse"
,
"bobsleigh"
,
"bolo tie"
,
"poke bonnet"
,
"bookcase"
,
"bookstore"
,
"bottle cap"
,
"hunting bow"
,
"bow tie"
,
"brass memorial plaque"
,
"bra"
,
"breakwater"
,
"breastplate"
,
"broom"
,
"bucket"
,
"buckle"
,
"bulletproof vest"
,
"high-speed train"
,
"butcher shop"
,
"taxicab"
,
"cauldron"
,
"candle"
,
"cannon"
,
"canoe"
,
"can opener"
,
"cardigan"
,
"car mirror"
,
"carousel"
,
"tool kit"
,
"cardboard box / carton"
,
"car wheel"
,
"automated teller machine"
,
"cassette"
,
"cassette player"
,
"castle"
,
"catamaran"
,
"CD player"
,
"cello"
,
"mobile phone"
,
"chain"
,
"chain-link fence"
,
"chain mail"
,
"chainsaw"
,
"storage chest"
,
"chiffonier"
,
"bell or wind chime"
,
"china cabinet"
,
"Christmas stocking"
,
"church"
,
"movie theater"
,
"cleaver"
,
"cliff dwelling"
,
"cloak"
,
"clogs"
,
"cocktail shaker"
,
"coffee mug"
,
"coffeemaker"
,
"spiral or coil"
,
"combination lock"
,
"computer keyboard"
,
"candy store"
,
"container ship"
,
"convertible"
,
"corkscrew"
,
"cornet"
,
"cowboy boot"
,
"cowboy hat"
,
"cradle"
,
"construction crane"
,
"crash helmet"
,
"crate"
,
"infant bed"
,
"Crock Pot"
,
"croquet ball"
,
"crutch"
,
"cuirass"
,
"dam"
,
"desk"
,
"desktop computer"
,
"rotary dial telephone"
,
"diaper"
,
"digital clock"
,
"digital watch"
,
"dining table"
,
"dishcloth"
,
"dishwasher"
,
"disc brake"
,
"dock"
,
"dog sled"
,
"dome"
,
"doormat"
,
"drilling rig"
,
"drum"
,
"drumstick"
,
"dumbbell"
,
"Dutch oven"
,
"electric fan"
,
"electric guitar"
,
"electric locomotive"
,
"entertainment center"
,
"envelope"
,
"espresso machine"
,
"face powder"
,
"feather boa"
,
"filing cabinet"
,
"fireboat"
,
"fire truck"
,
"fire screen"
,
"flagpole"
,
"flute"
,
"folding chair"
,
"football helmet"
,
"forklift"
,
"fountain"
,
"fountain pen"
,
"four-poster bed"
,
"freight car"
,
"French horn"
,
"frying pan"
,
"fur coat"
,
"garbage truck"
,
"gas mask or respirator"
,
"gas pump"
,
"goblet"
,
"go-kart"
,
"golf ball"
,
"golf cart"
,
"gondola"
,
"gong"
,
"gown"
,
"grand piano"
,
"greenhouse"
,
"radiator grille"
,
"grocery store"
,
"guillotine"
,
"hair clip"
,
"hair spray"
,
"half-track"
,
"hammer"
,
"hamper"
,
"hair dryer"
,
"hand-held computer"
,
"handkerchief"
,
"hard disk drive"
,
"harmonica"
,
"harp"
,
"combine harvester"
,
"hatchet"
,
"holster"
,
"home theater"
,
"honeycomb"
,
"hook"
,
"hoop skirt"
,
"gymnastic horizontal bar"
,
"horse-drawn vehicle"
,
"hourglass"
,
"iPod"
,
"clothes iron"
,
"carved pumpkin"
,
"jeans"
,
"jeep"
,
"T-shirt"
,
"jigsaw puzzle"
,
"rickshaw"
,
"joystick"
,
"kimono"
,
"knee pad"
,
"knot"
,
"lab coat"
,
"ladle"
,
"lampshade"
,
"laptop computer"
,
"lawn mower"
,
"lens cap"
,
"letter opener"
,
"library"
,
"lifeboat"
,
"lighter"
,
"limousine"
,
"ocean liner"
,
"lipstick"
,
"slip-on shoe"
,
"lotion"
,
"music speaker"
,
"loupe magnifying glass"
,
"sawmill"
,
"magnetic compass"
,
"messenger bag"
,
"mailbox"
,
"tights"
,
"one-piece bathing suit"
,
"manhole cover"
,
"maraca"
,
"marimba"
,
"mask"
,
"matchstick"
,
"maypole"
,
"maze"
,
"measuring cup"
,
"medicine cabinet"
,
"megalith"
,
"microphone"
,
"microwave oven"
,
"military uniform"
,
"milk can"
,
"minibus"
,
"miniskirt"
,
"minivan"
,
"missile"
,
"mitten"
,
"mixing bowl"
,
"mobile home"
,
"ford model t"
,
"modem"
,
"monastery"
,
"monitor"
,
"moped"
,
"mortar and pestle"
,
"graduation cap"
,
"mosque"
,
"mosquito net"
,
"vespa"
,
"mountain bike"
,
"tent"
,
"computer mouse"
,
"mousetrap"
,
"moving van"
,
"muzzle"
,
"metal nail"
,
"neck brace"
,
"necklace"
,
"baby pacifier"
,
"notebook computer"
,
"obelisk"
,
"oboe"
,
"ocarina"
,
"odometer"
,
"oil filter"
,
"pipe organ"
,
"oscilloscope"
,
"overskirt"
,
"bullock cart"
,
"oxygen mask"
,
"product packet / packaging"
,
"paddle"
,
"paddle wheel"
,
"padlock"
,
"paintbrush"
,
"pajamas"
,
"palace"
,
"pan flute"
,
"paper towel"
,
"parachute"
,
"parallel bars"
,
"park bench"
,
"parking meter"
,
"railroad car"
,
"patio"
,
"payphone"
,
"pedestal"
,
"pencil case"
,
"pencil sharpener"
,
"perfume"
,
"Petri dish"
,
"photocopier"
,
"plectrum"
,
"Pickelhaube"
,
"picket fence"
,
"pickup truck"
,
"pier"
,
"piggy bank"
,
"pill bottle"
,
"pillow"
,
"ping-pong ball"
,
"pinwheel"
,
"pirate ship"
,
"drink pitcher"
,
"block plane"
,
"planetarium"
,
"plastic bag"
,
"plate rack"
,
"farm plow"
,
"plunger"
,
"Polaroid camera"
,
"pole"
,
"police van"
,
"poncho"
,
"pool table"
,
"soda bottle"
,
"plant pot"
,
"potter's wheel"
,
"power drill"
,
"prayer rug"
,
"printer"
,
"prison"
,
"missile"
,
"projector"
,
"hockey puck"
,
"punching bag"
,
"purse"
,
"quill"
,
"quilt"
,
"race car"
,
"racket"
,
"radiator"
,
"radio"
,
"radio telescope"
,
"rain barrel"
,
"recreational vehicle"
,
"fishing casting reel"
,
"reflex camera"
,
"refrigerator"
,
"remote control"
,
"restaurant"
,
"revolver"
,
"rifle"
,
"rocking chair"
,
"rotisserie"
,
"eraser"
,
"rugby ball"
,
"ruler measuring stick"
,
"sneaker"
,
"safe"
,
"safety pin"
,
"salt shaker"
,
"sandal"
,
"sarong"
,
"saxophone"
,
"scabbard"
,
"weighing scale"
,
"school bus"
,
"schooner"
,
"scoreboard"
,
"CRT monitor"
,
"screw"
,
"screwdriver"
,
"seat belt"
,
"sewing machine"
,
"shield"
,
"shoe store"
,
"shoji screen / room divider"
,
"shopping basket"
,
"shopping cart"
,
"shovel"
,
"shower cap"
,
"shower curtain"
,
"ski"
,
"balaclava ski mask"
,
"sleeping bag"
,
"slide rule"
,
"sliding door"
,
"slot machine"
,
"snorkel"
,
"snowmobile"
,
"snowplow"
,
"soap dispenser"
,
"soccer ball"
,
"sock"
,
"solar thermal collector"
,
"sombrero"
,
"soup bowl"
,
"keyboard space bar"
,
"space heater"
,
"space shuttle"
,
"spatula"
,
"motorboat"
,
"spider web"
,
"spindle"
,
"sports car"
,
"spotlight"
,
"stage"
,
"steam locomotive"
,
"through arch bridge"
,
"steel drum"
,
"stethoscope"
,
"scarf"
,
"stone wall"
,
"stopwatch"
,
"stove"
,
"strainer"
,
"tram"
,
"stretcher"
,
"couch"
,
"stupa"
,
"submarine"
,
"suit"
,
"sundial"
,
"sunglasses"
,
"sunglasses"
,
"sunscreen"
,
"suspension bridge"
,
"mop"
,
"sweatshirt"
,
"swim trunks / shorts"
,
"swing"
,
"electrical switch"
,
"syringe"
,
"table lamp"
,
"tank"
,
"tape player"
,
"teapot"
,
"teddy bear"
,
"television"
,
"tennis ball"
,
"thatched roof"
,
"front curtain"
,
"thimble"
,
"threshing machine"
,
"throne"
,
"tile roof"
,
"toaster"
,
"tobacco shop"
,
"toilet seat"
,
"torch"
,
"totem pole"
,
"tow truck"
,
"toy store"
,
"tractor"
,
"semi-trailer truck"
,
"tray"
,
"trench coat"
,
"tricycle"
,
"trimaran"
,
"tripod"
,
"triumphal arch"
,
"trolleybus"
,
"trombone"
,
"hot tub"
,
"turnstile"
,
"typewriter keyboard"
,
"umbrella"
,
"unicycle"
,
"upright piano"
,
"vacuum cleaner"
,
"vase"
,
"vaulted or arched ceiling"
,
"velvet fabric"
,
"vending machine"
,
"vestment"
,
"viaduct"
,
"violin"
,
"volleyball"
,
"waffle iron"
,
"wall clock"
,
"wallet"
,
"wardrobe"
,
"military aircraft"
,
"sink"
,
"washing machine"
,
"water bottle"
,
"water jug"
,
"water tower"
,
"whiskey jug"
,
"whistle"
,
"hair wig"
,
"window screen"
,
"window shade"
,
"Windsor tie"
,
"wine bottle"
,
"airplane wing"
,
"wok"
,
"wooden spoon"
,
"wool"
,
"split-rail fence"
,
"shipwreck"
,
"sailboat"
,
"yurt"
,
"website"
,
"comic book"
,
"crossword"
,
"traffic or street sign"
,
"traffic light"
,
"dust jacket"
,
"menu"
,
"plate"
,
"guacamole"
,
"consomme"
,
"hot pot"
,
"trifle"
,
"ice cream"
,
"popsicle"
,
"baguette"
,
"bagel"
,
"pretzel"
,
"cheeseburger"
,
"hot dog"
,
"mashed potatoes"
,
"cabbage"
,
"broccoli"
,
"cauliflower"
,
"zucchini"
,
"spaghetti squash"
,
"acorn squash"
,
"butternut squash"
,
"cucumber"
,
"artichoke"
,
"bell pepper"
,
"cardoon"
,
"mushroom"
,
"Granny Smith apple"
,
"strawberry"
,
"orange"
,
"lemon"
,
"fig"
,
"pineapple"
,
"banana"
,
"jackfruit"
,
"cherimoya (custard apple)"
,
"pomegranate"
,
"hay"
,
"carbonara"
,
"chocolate syrup"
,
"dough"
,
"meatloaf"
,
"pizza"
,
"pot pie"
,
"burrito"
,
"red wine"
,
"espresso"
,
"tea cup"
,
"eggnog"
,
"mountain"
,
"bubble"
,
"cliff"
,
"coral reef"
,
"geyser"
,
"lakeshore"
,
"promontory"
,
"sandbar"
,
"beach"
,
"valley"
,
"volcano"
,
"baseball player"
,
"bridegroom"
,
"scuba diver"
,
"rapeseed"
,
"daisy"
,
"yellow lady's slipper"
,
"corn"
,
"acorn"
,
"rose hip"
,
"horse chestnut seed"
,
"coral fungus"
,
"agaric"
,
"gyromitra"
,
"stinkhorn mushroom"
,
"earth star fungus"
,
"hen of the woods mushroom"
,
"bolete"
,
"corn cob"
,
"toilet paper"
]
# The 80 prompt templates used by OpenAI for zero-shot ImageNet classification.
# Each entry maps a class-name string `c` to a full caption; at evaluation time
# every template is applied to every class name and the resulting text
# embeddings are averaged into one classifier weight per class.
openai_imagenet_template = [
    lambda c: f'a bad photo of a {c}.',
    lambda c: f'a photo of many {c}.',
    lambda c: f'a sculpture of a {c}.',
    lambda c: f'a photo of the hard to see {c}.',
    lambda c: f'a low resolution photo of the {c}.',
    lambda c: f'a rendering of a {c}.',
    lambda c: f'graffiti of a {c}.',
    lambda c: f'a bad photo of the {c}.',
    lambda c: f'a cropped photo of the {c}.',
    lambda c: f'a tattoo of a {c}.',
    lambda c: f'the embroidered {c}.',
    lambda c: f'a photo of a hard to see {c}.',
    lambda c: f'a bright photo of a {c}.',
    lambda c: f'a photo of a clean {c}.',
    lambda c: f'a photo of a dirty {c}.',
    lambda c: f'a dark photo of the {c}.',
    lambda c: f'a drawing of a {c}.',
    lambda c: f'a photo of my {c}.',
    lambda c: f'the plastic {c}.',
    lambda c: f'a photo of the cool {c}.',
    lambda c: f'a close-up photo of a {c}.',
    lambda c: f'a black and white photo of the {c}.',
    lambda c: f'a painting of the {c}.',
    lambda c: f'a painting of a {c}.',
    lambda c: f'a pixelated photo of the {c}.',
    lambda c: f'a sculpture of the {c}.',
    lambda c: f'a bright photo of the {c}.',
    lambda c: f'a cropped photo of a {c}.',
    lambda c: f'a plastic {c}.',
    lambda c: f'a photo of the dirty {c}.',
    lambda c: f'a jpeg corrupted photo of a {c}.',
    lambda c: f'a blurry photo of the {c}.',
    lambda c: f'a photo of the {c}.',
    lambda c: f'a good photo of the {c}.',
    lambda c: f'a rendering of the {c}.',
    lambda c: f'a {c} in a video game.',
    lambda c: f'a photo of one {c}.',
    lambda c: f'a doodle of a {c}.',
    lambda c: f'a close-up photo of the {c}.',
    lambda c: f'a photo of a {c}.',
    lambda c: f'the origami {c}.',
    lambda c: f'the {c} in a video game.',
    lambda c: f'a sketch of a {c}.',
    lambda c: f'a doodle of the {c}.',
    lambda c: f'a origami {c}.',
    lambda c: f'a low resolution photo of a {c}.',
    lambda c: f'the toy {c}.',
    lambda c: f'a rendition of the {c}.',
    lambda c: f'a photo of the clean {c}.',
    lambda c: f'a photo of a large {c}.',
    lambda c: f'a rendition of a {c}.',
    lambda c: f'a photo of a nice {c}.',
    lambda c: f'a photo of a weird {c}.',
    lambda c: f'a blurry photo of a {c}.',
    lambda c: f'a cartoon {c}.',
    lambda c: f'art of a {c}.',
    lambda c: f'a sketch of the {c}.',
    lambda c: f'a embroidered {c}.',
    lambda c: f'a pixelated photo of a {c}.',
    lambda c: f'itap of the {c}.',
    lambda c: f'a jpeg corrupted photo of the {c}.',
    lambda c: f'a good photo of a {c}.',
    lambda c: f'a plushie {c}.',
    lambda c: f'a photo of the nice {c}.',
    lambda c: f'a photo of the small {c}.',
    lambda c: f'a photo of the weird {c}.',
    lambda c: f'the cartoon {c}.',
    lambda c: f'art of the {c}.',
    lambda c: f'a drawing of the {c}.',
    lambda c: f'a photo of the large {c}.',
    lambda c: f'a black and white photo of a {c}.',
    lambda c: f'the plushie {c}.',
    lambda c: f'a dark photo of a {c}.',
    lambda c: f'itap of a {c}.',
    lambda c: f'graffiti of the {c}.',
    lambda c: f'a toy {c}.',
    lambda c: f'itap of my {c}.',
    lambda c: f'a photo of a cool {c}.',
    lambda c: f'a photo of a small {c}.',
    lambda c: f'a tattoo of the {c}.',
]
open_clip/src/training/logger.py
0 → 100644
View file @
f55a786e
import
logging
def setup_logging(log_file, level, include_host=False):
    """Configure the root logger: apply *level* everywhere, attach a console
    handler, and optionally a file handler when *log_file* is truthy.

    When include_host is True the hostname is baked into the log format,
    which helps when aggregating logs from multiple nodes.
    """
    if include_host:
        import socket
        host = socket.gethostname()
        fmt = logging.Formatter(
            f'%(asctime)s |  {host} | %(levelname)s | %(message)s',
            datefmt='%Y-%m-%d,%H:%M:%S')
    else:
        fmt = logging.Formatter(
            '%(asctime)s | %(levelname)s | %(message)s',
            datefmt='%Y-%m-%d,%H:%M:%S')

    logging.root.setLevel(level)
    # Bring every logger that already exists up/down to the requested level too.
    for logger_name in logging.root.manager.loggerDict:
        logging.getLogger(logger_name).setLevel(level)

    console_handler = logging.StreamHandler()
    console_handler.setFormatter(fmt)
    logging.root.addHandler(console_handler)

    if log_file:
        file_handler = logging.FileHandler(filename=log_file)
        file_handler.setFormatter(fmt)
        logging.root.addHandler(file_handler)
open_clip/src/training/main.py
0 → 100644
View file @
f55a786e
import
glob
import
logging
import
os
import
re
import
subprocess
import
sys
import
random
from
datetime
import
datetime
import
numpy
as
np
import
torch
from
torch
import
optim
from
torch.cuda.amp
import
GradScaler
try
:
import
wandb
except
ImportError
:
wandb
=
None
try
:
import
torch.utils.tensorboard
as
tensorboard
except
ImportError
:
tensorboard
=
None
try
:
import
horovod.torch
as
hvd
except
ImportError
:
hvd
=
None
from
open_clip
import
create_model_and_transforms
,
trace_model
,
get_tokenizer
from
training.data
import
get_data
from
training.distributed
import
is_master
,
init_distributed_device
,
broadcast_object
from
training.logger
import
setup_logging
from
training.params
import
parse_args
from
training.scheduler
import
cosine_lr
,
const_lr
,
const_lr_cooldown
from
training.train
import
train_one_epoch
,
evaluate
from
training.file_utils
import
pt_load
,
check_exists
,
start_sync_process
,
remote_sync
LATEST_CHECKPOINT_NAME
=
"epoch_latest.pt"
def random_seed(seed=42, rank=0):
    """Seed torch, numpy and the stdlib RNGs deterministically.

    The rank is folded into the seed so that each distributed worker
    draws a distinct but reproducible random stream.
    """
    effective_seed = seed + rank
    torch.manual_seed(effective_seed)
    np.random.seed(effective_seed)
    random.seed(effective_seed)
def natural_key(string_):
    """Sort key for "natural" ordering (digit runs compared numerically).

    See http://www.codinghorror.com/blog/archives/001018.html
    """
    pieces = re.split(r'(\d+)', string_.lower())
    return [int(piece) if piece.isdigit() else piece for piece in pieces]
def get_latest_checkpoint(path: str, remote: bool):
    """Return the newest checkpoint file under *path*, or None if none exist.

    "Newest" is decided by natural-sort order of the file names, so
    epoch_10.pt sorts after epoch_2.pt. When remote is True the listing is
    taken from `aws s3 ls`; otherwise a recursive glob is used.
    """
    # as writen, this glob recurses, so can pick up checkpoints across multiple sub-folders
    if remote:
        listing = subprocess.run(
            ["aws", "s3", "ls", path + "/"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        print(listing)
        if listing.returncode == 1:
            return None
        # last whitespace-separated field of each `aws s3 ls` line is the key name
        lines = listing.stdout.decode().split('\n')[:-1]
        checkpoints = [os.path.join(path, entry.split(' ')[-1]) for entry in lines]
    else:
        checkpoints = glob.glob(path + '**/*.pt', recursive=True)
    if not checkpoints:
        return None
    return sorted(checkpoints, key=natural_key)[-1]
def main(args):
    """Entry point for CLIP training / evaluation.

    Parses CLI args, initializes logging and (optionally) distributed state,
    builds model/optimizer/data, optionally resumes from a checkpoint, then
    runs the train + eval loop, saving checkpoints and syncing logs remotely.
    Returns -1 on setup errors, otherwise None.
    """
    args = parse_args(args)

    if torch.cuda.is_available():
        # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
        # This was a default in pytorch until 1.12
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    # fully initialize distributed device environment
    device = init_distributed_device(args)

    # get the name of the experiments
    if args.name is None:
        # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule?
        model_name_safe = args.model.replace('/', '-')
        # NOTE(review): the backtick inside this strftime format looks like a typo
        # (upstream open_clip uses "%Y_%m_%d-%H_%M_%S") — it ends up verbatim in
        # experiment names; confirm before fixing since existing runs may use it.
        date_str = datetime.now().strftime("%Y_%m_%d`-%H_%M_%S")
        if args.distributed:
            # sync date_str from master to all ranks
            date_str = broadcast_object(args, date_str)
        args.name = '-'.join([
            date_str,
            f"model_{model_name_safe}",
            f"lr_{args.lr}",
            f"b_{args.batch_size}",
            f"j_{args.workers}",
            f"p_{args.precision}",
        ])

    resume_latest = args.resume == 'latest'
    log_base_path = os.path.join(args.logs, args.name)
    args.log_path = None
    if is_master(args, local=args.log_local):
        os.makedirs(log_base_path, exist_ok=True)
        log_filename = f'out-{args.rank}' if args.log_local else 'out.log'
        args.log_path = os.path.join(log_base_path, log_filename)
        if os.path.exists(args.log_path) and not resume_latest:
            # NOTE(review): the "{}" placeholder is never filled in (no .format call).
            print("Error. Experiment already exists. Use --name {} to specify a new experiment.")
            return -1

    # Setup text logger
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)

    # Setup wandb, tensorboard, checkpoint logging
    args.wandb = 'wandb' in args.report_to or 'all' in args.report_to
    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
    args.checkpoint_path = os.path.join(log_base_path, "checkpoints")
    if is_master(args):
        args.tensorboard_path = os.path.join(log_base_path, "tensorboard") if args.tensorboard else ''
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        args.tensorboard_path = ''

    if resume_latest:
        resume_from = None
        checkpoint_path = args.checkpoint_path
        # If using remote_sync, need to check the remote instead of the local checkpoints folder.
        if args.remote_sync is not None:
            checkpoint_path = os.path.join(args.remote_sync, args.name, "checkpoints")
            if args.save_most_recent:
                print('Error. Cannot use save-most-recent with remote_sync and resume latest.')
                return -1
            if args.remote_sync_protocol != 's3':
                print('Error. Sync protocol not supported when using resume latest.')
                return -1
        if is_master(args):
            # Checking for existing checkpoint via master rank only. It is possible for
            # different rank processes to see different files if a shared file-system is under
            # stress, however it's very difficult to fully work around such situations.
            if args.save_most_recent:
                # if --save-most-recent flag is set, look for latest at a fixed filename
                resume_from = os.path.join(checkpoint_path, LATEST_CHECKPOINT_NAME)
                if not os.path.exists(resume_from):
                    # If no latest checkpoint has been saved yet, don't try to resume
                    resume_from = None
            else:
                # otherwise, list checkpoint dir contents and pick the newest checkpoint
                resume_from = get_latest_checkpoint(checkpoint_path, remote=args.remote_sync is not None)
            if resume_from:
                logging.info(f'Found latest resume checkpoint at {resume_from}.')
            else:
                logging.info(f'No latest resume checkpoint found in {checkpoint_path}.')
        if args.distributed:
            # sync found checkpoint path to all ranks
            resume_from = broadcast_object(args, resume_from)
        args.resume = resume_from

    if args.copy_codebase:
        copy_codebase(args)

    # start the sync proces if remote-sync is not None
    remote_sync_process = None
    if is_master(args) and args.remote_sync is not None:
        # first make sure it works
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('remote sync successful.')
        else:
            logging.info('Error: remote sync failed. Exiting.')
            return -1
        # if all looks good, start a process to do this every args.remote_sync_frequency seconds
        remote_sync_process = start_sync_process(
            args.remote_sync_frequency,
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        remote_sync_process.start()

    if args.precision == 'fp16':
        logging.warning(
            'It is recommended to use AMP mixed-precision instead of FP16. '
            'FP16 support needs further verification and tuning, especially for train.')

    if args.horovod:
        logging.info(
            f'Running in horovod mode with multiple processes / nodes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    elif args.distributed:
        logging.info(
            f'Running in distributed mode with multiple processes. Device: {args.device}.'
            f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.')
    else:
        logging.info(f'Running with a single process. Device {args.device}.')

    if isinstance(args.force_image_size, (tuple, list)) and len(args.force_image_size) == 1:
        # arg is nargs, single (square) image size list -> int
        args.force_image_size = args.force_image_size[0]
    # seed identically on all ranks before model creation so weights match
    random_seed(args.seed, 0)
    model, preprocess_train, preprocess_val = create_model_and_transforms(
        args.model,
        args.pretrained,
        precision=args.precision,
        device=device,
        jit=args.torchscript,
        force_quick_gelu=args.force_quick_gelu,
        force_custom_text=args.force_custom_text,
        force_patch_dropout=args.force_patch_dropout,
        force_image_size=args.force_image_size,
        pretrained_image=args.pretrained_image,
        image_mean=args.image_mean,
        image_std=args.image_std,
        aug_cfg=args.aug_cfg,
    )
    # then re-seed per-rank so data augmentation etc. diverges across workers
    random_seed(args.seed, args.rank)

    if args.trace:
        model = trace_model(model, batch_size=args.batch_size, device=device)

    if args.lock_image:
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        model.lock_image_tower(
            unlocked_groups=args.lock_image_unlocked_groups,
            freeze_bn_stats=args.lock_image_freeze_bn_stats)
    if args.lock_text:
        model.lock_text_tower(
            unlocked_layers=args.lock_text_unlocked_layers,
            freeze_layer_norm=args.lock_text_freeze_layer_norm)

    if args.grad_checkpointing:
        model.set_grad_checkpointing()

    if is_master(args):
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    if args.distributed and not args.horovod:
        if args.use_bn_sync:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        ddp_args = {}
        if args.ddp_static_graph:
            # this doesn't exist in older PyTorch, arg only added if enabled
            ddp_args['static_graph'] = True
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args)

    # create optimizer and scaler
    optimizer = None
    scaler = None

    if args.train_data or args.dataset_type == "synthetic":
        assert not args.trace, 'Cannot train with traced model'

        # gains/biases/norms/logit_scale get zero weight decay; everything else gets args.wd
        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
        include = lambda n, p: not exclude(n, p)

        named_parameters = list(model.named_parameters())
        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]

        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )
        if args.horovod:
            optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
            hvd.broadcast_parameters(model.state_dict(), root_rank=0)
            hvd.broadcast_optimizer_state(optimizer, root_rank=0)

        scaler = GradScaler() if args.precision == "amp" else None

    # optionally resume from a checkpoint
    start_epoch = 0
    if args.resume is not None:
        checkpoint = pt_load(args.resume, map_location='cpu')
        if 'epoch' in checkpoint:
            # resuming a train checkpoint w/ epoch and optimizer state
            start_epoch = checkpoint["epoch"]
            sd = checkpoint["state_dict"]
            # strip the DDP 'module.' prefix when loading into a non-distributed model
            if not args.distributed and next(iter(sd.items()))[0].startswith('module'):
                sd = {k[len('module.'):]: v for k, v in sd.items()}
            model.load_state_dict(sd)
            if optimizer is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])
            if scaler is not None and 'scaler' in checkpoint:
                scaler.load_state_dict(checkpoint['scaler'])
            logging.info(f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})")
        else:
            # loading a bare (model only) checkpoint for fine-tune or evaluation
            model.load_state_dict(checkpoint)
            logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")

    # initialize datasets
    data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
    assert len(data), 'At least one train or eval dataset must be specified.'

    # create scheduler if train
    scheduler = None
    if 'train' in data and optimizer is not None:
        total_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs
        if args.lr_scheduler == "cosine":
            scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)
        elif args.lr_scheduler == "const":
            scheduler = const_lr(optimizer, args.lr, args.warmup, total_steps)
        elif args.lr_scheduler == "const-cooldown":
            assert args.epochs_cooldown is not None, \
                "Please specify the number of cooldown epochs for this lr schedule."
            cooldown_steps = (data["train"].dataloader.num_batches // args.accum_freq) * args.epochs_cooldown
            scheduler = const_lr_cooldown(
                optimizer, args.lr, args.warmup, total_steps, cooldown_steps, args.lr_cooldown_power, args.lr_cooldown_end)
        else:
            logging.error(
                f'Unknown scheduler, {args.lr_scheduler}. Available options are: cosine, const, const-cooldown.')
            exit(1)

    # determine if this worker should save logs and checkpoints. only do so if it is rank == 0
    args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args)
    writer = None
    if args.save_logs and args.tensorboard:
        assert tensorboard is not None, "Please install tensorboard."
        writer = tensorboard.SummaryWriter(args.tensorboard_path)

    if args.wandb and is_master(args):
        assert wandb is not None, 'Please install wandb.'
        logging.debug('Starting wandb.')
        args.train_sz = data["train"].dataloader.num_samples
        if args.val_data is not None:
            args.val_sz = data["val"].dataloader.num_samples
        # you will have to configure this for your project!
        wandb.init(
            project=args.wandb_project_name,
            name=args.name,
            id=args.name,
            notes=args.wandb_notes,
            tags=[],
            resume='auto' if args.resume == "latest" else None,
            config=vars(args),
        )
        if args.debug:
            wandb.watch(model, log='all')
        wandb.save(params_file)
        logging.debug('Finished loading wandb.')

    if 'train' not in data:
        # eval-only mode: a single evaluate pass, then exit
        evaluate(model, data, start_epoch, args, writer)
        return

    for epoch in range(start_epoch, args.epochs):
        if is_master(args):
            logging.info(f'Start epoch {epoch}')

        train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer)
        completed_epoch = epoch + 1

        if any(v in data for v in ('val', 'imagenet-val', 'imagenet-v2')):
            evaluate(model, data, completed_epoch, args, writer)

        # Saving checkpoints.
        if args.save_logs:
            checkpoint_dict = {
                "epoch": completed_epoch,
                "name": args.name,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if scaler is not None:
                checkpoint_dict["scaler"] = scaler.state_dict()

            if completed_epoch == args.epochs or (
                args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0
            ):
                torch.save(
                    checkpoint_dict,
                    os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"),
                )
            if args.delete_previous_checkpoint:
                previous_checkpoint = os.path.join(args.checkpoint_path, f"epoch_{completed_epoch - 1}.pt")
                if os.path.exists(previous_checkpoint):
                    os.remove(previous_checkpoint)

            if args.save_most_recent:
                # try not to corrupt the latest checkpoint if save fails
                tmp_save_path = os.path.join(args.checkpoint_path, "tmp.pt")
                latest_save_path = os.path.join(args.checkpoint_path, LATEST_CHECKPOINT_NAME)
                torch.save(checkpoint_dict, tmp_save_path)
                os.replace(tmp_save_path, latest_save_path)

    if args.wandb and is_master(args):
        wandb.finish()

    # run a final sync.
    if remote_sync_process is not None:
        logging.info('Final remote sync.')
        remote_sync_process.terminate()
        result = remote_sync(
            os.path.join(args.logs, args.name),
            os.path.join(args.remote_sync, args.name),
            args.remote_sync_protocol
        )
        if result:
            logging.info('Final remote sync successful.')
        else:
            logging.info('Final remote sync failed.')
def copy_codebase(args):
    """Snapshot the repository into <logs>/<name>/code for reproducibility.

    Returns -1 if the destination already exists (refuses to overwrite),
    otherwise 1 on success.
    """
    from shutil import copytree, ignore_patterns
    new_code_path = os.path.join(args.logs, args.name, "code")
    if os.path.exists(new_code_path):
        print(
            f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment."
        )
        return -1
    print(f"Copying codebase to {new_code_path}")
    # climb three directory levels from this file — presumably file -> training
    # -> src -> repo root; verify if the layout changes.
    source_root = os.path.realpath(__file__)
    source_root = os.path.dirname(os.path.dirname(os.path.dirname(source_root)))
    copytree(source_root, new_code_path, ignore=ignore_patterns('log', 'logs', 'wandb'))
    print("Done copying code.")
    return 1
# Script entry point: forward CLI arguments (minus the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])
open_clip/src/training/params.py
0 → 100644
View file @
f55a786e
import
argparse
import
ast
def get_default_params(model_name):
    """Return default AdamW hyper-parameters keyed off the model family.

    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    ViT models use the paper's ViT settings; everything else gets the
    ResNet settings.
    """
    is_vit = "vit" in model_name.lower()
    if is_vit:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}
class ParseKwargs(argparse.Action):
    """argparse action that collects ``key=value`` tokens into a dict.

    Each value is parsed with ``ast.literal_eval`` when possible (numbers,
    tuples, booleans, ...) and falls back to the raw string otherwise.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        kw = {}
        for value in values:
            # split on the first '=' only, so values may themselves contain '='
            key, value = value.split('=', 1)
            try:
                kw[key] = ast.literal_eval(value)
            except (ValueError, SyntaxError):
                # literal_eval raises SyntaxError (not just ValueError) on
                # inputs like 'a b'; either way fall back to the raw string
                # (avoids the need to escape on the command line)
                kw[key] = str(value)
        setattr(namespace, self.dest, kw)
def parse_args(args):
    """Build the full training CLI, parse *args*, and back-fill any unset
    optimizer hyper-parameters with model-dependent defaults.

    Returns the populated argparse.Namespace.
    """
    parser = argparse.ArgumentParser()
    # --- data sources ---
    parser.add_argument("--train-data", type=str, default=None, help="Path to file(s) with training data",)
    parser.add_argument("--val-data", type=str, default=None, help="Path to file(s) with validation data",)
    parser.add_argument("--train-num-samples", type=int, default=None, help="Number of samples in dataset. Required for webdataset if not available in info file.",)
    parser.add_argument("--val-num-samples", type=int, default=None, help="Number of samples in dataset. Useful for webdataset if not available in info file.",)
    parser.add_argument("--dataset-type", choices=["webdataset", "csv", "synthetic", "auto"], default="auto", help="Which type of dataset to process.")
    parser.add_argument("--dataset-resampled", default=False, action="store_true", help="Whether to use sampling with replacement for webdataset shard selection.")
    parser.add_argument("--csv-separator", type=str, default="\t", help="For csv-like datasets, which separator to use.")
    parser.add_argument("--csv-img-key", type=str, default="filepath", help="For csv-like datasets, the name of the key for the image paths.")
    parser.add_argument("--csv-caption-key", type=str, default="title", help="For csv-like datasets, the name of the key for the captions.")
    parser.add_argument("--imagenet-val", type=str, default=None, help="Path to imagenet val set for conducting zero shot evaluation.",)
    parser.add_argument("--imagenet-v2", type=str, default=None, help="Path to imagenet v2 for conducting zero shot evaluation.",)
    # --- logging / experiment naming ---
    parser.add_argument("--logs", type=str, default="./logs/", help="Where to store tensorboard logs. Use None to avoid storing logs.",)
    parser.add_argument("--log-local", action="store_true", default=False, help="log files on local master, otherwise global master only.",)
    parser.add_argument("--name", type=str, default=None, help="Optional identifier for the experiment when storing logs. Otherwise use current time.",)
    # --- core training hyper-parameters ---
    parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.")
    parser.add_argument("--batch-size", type=int, default=64, help="Batch size per GPU.")
    parser.add_argument("--epochs", type=int, default=32, help="Number of epochs to train for.")
    parser.add_argument("--epochs-cooldown", type=int, default=None, help="When scheduler w/ cooldown used, perform cooldown from total_epochs - cooldown_epochs onwards.")
    # lr/beta1/beta2/eps default to None so model-specific defaults can be applied below
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")
    parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.")
    parser.add_argument("--use-bn-sync", default=False, action="store_true", help="Whether to use batch norm sync.")
    parser.add_argument("--skip-scheduler", action="store_true", default=False, help="Use this flag to skip the learning rate decay.",)
    parser.add_argument("--lr-scheduler", type=str, default='cosine', help="LR scheduler. One of: 'cosine', 'const' (constant), 'const-cooldown' (constant w/ cooldown). Default: cosine",)
    parser.add_argument("--lr-cooldown-end", type=float, default=0.0, help="End learning rate for cooldown schedule. Default: 0")
    parser.add_argument("--lr-cooldown-power", type=float, default=1.0, help="Power for polynomial cooldown schedule. Default: 1.0 (linear decay)")
    # --- checkpointing / evaluation cadence ---
    parser.add_argument("--save-frequency", type=int, default=1, help="How often to save checkpoints.")
    parser.add_argument("--save-most-recent", action="store_true", default=False, help="Always save the most recent model trained to epoch_latest.pt.",)
    parser.add_argument("--zeroshot-frequency", type=int, default=2, help="How often to run zero shot.")
    parser.add_argument("--val-frequency", type=int, default=1, help="How often to run evaluation with val data.")
    parser.add_argument("--resume", default=None, type=str, help="path to latest checkpoint (default: none)",)
    parser.add_argument("--precision", choices=["amp", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], default="amp", help="Floating point precision.")
    # --- model selection / tower locking ---
    parser.add_argument("--model", type=str, default="RN50", help="Name of the vision backbone to use.",)
    parser.add_argument("--pretrained", default='', type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.",)
    parser.add_argument("--pretrained-image", default=False, action='store_true', help="Load imagenet pretrained weights for image tower backbone if available.",)
    parser.add_argument("--lock-image", default=False, action='store_true', help="Lock full image tower by disabling gradients.",)
    parser.add_argument("--lock-image-unlocked-groups", type=int, default=0, help="Leave last n image tower layer groups unlocked.",)
    parser.add_argument("--lock-image-freeze-bn-stats", default=False, action='store_true', help="Freeze BatchNorm running stats in image tower for any locked layers.",)
    parser.add_argument('--image-mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override default image mean value of dataset')
    # NOTE(review): "of of" in the help below is a typo in the original text.
    parser.add_argument('--image-std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset')
    parser.add_argument('--aug-cfg', nargs='*', default={}, action=ParseKwargs)
    parser.add_argument("--grad-checkpointing", default=False, action='store_true', help="Enable gradient checkpointing.",)
    parser.add_argument("--local-loss", default=False, action="store_true", help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)")
    parser.add_argument("--gather-with-grad", default=False, action="store_true", help="enable full distributed gradient for feature gather")
    parser.add_argument('--force-image-size', type=int, nargs='+', default=None, help='Override default image size')
    parser.add_argument("--force-quick-gelu", default=False, action='store_true', help="Force use of QuickGELU activation for non-OpenAI transformer models.",)
    parser.add_argument("--force-patch-dropout", default=None, type=float, help="Override the patch dropout during training, for fine tuning with no dropout near the end as in the paper",)
    parser.add_argument("--force-custom-text", default=False, action='store_true', help="Force use of CustomTextCLIP model (separate text-tower).",)
    parser.add_argument("--torchscript", default=False, action='store_true', help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",)
    parser.add_argument("--trace", default=False, action='store_true', help="torch.jit.trace the model for inference / eval only",)
    parser.add_argument("--accum-freq", type=int, default=1, help="Update the model every --acum-freq steps.")
    # arguments for distributed training
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training",)
    parser.add_argument("--dist-backend", default="nccl", type=str, help="distributed backend")
    # --- experiment tracking ---
    parser.add_argument("--report-to", default='', type=str, help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']")
    parser.add_argument("--wandb-notes", default='', type=str, help="Notes if logging with wandb")
    parser.add_argument("--wandb-project-name", type=str, default='open-clip', help="Name of the project if logging with wandb.",)
    parser.add_argument("--debug", default=False, action="store_true", help="If true, more information is logged.")
    parser.add_argument("--copy-codebase", default=False, action="store_true", help="If true, we copy the entire base on the log directory, and execute from there.")
    parser.add_argument("--horovod", default=False, action="store_true", help="Use horovod for distributed training.")
    parser.add_argument("--ddp-static-graph", default=False, action='store_true', help="Enable static graph optimization for DDP in PyTorch >= 1.11.",)
    parser.add_argument("--no-set-device-rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).")
    parser.add_argument("--seed", type=int, default=0, help="Default random seed.")
    parser.add_argument("--grad-clip-norm", type=float, default=None, help="Gradient clip.")
    parser.add_argument("--lock-text", default=False, action='store_true', help="Lock full text tower by disabling gradients.",)
    parser.add_argument("--lock-text-unlocked-layers", type=int, default=0, help="Leave last n image tower layer groups unlocked.",)
    parser.add_argument("--lock-text-freeze-layer-norm", default=False, action='store_true', help="Freeze BatchNorm running stats in image tower for any locked layers.",)
    parser.add_argument("--log-every-n-steps", type=int, default=100, help="Log every n steps to tensorboard/console/wandb.",)
    # --- remote log/checkpoint sync ---
    parser.add_argument("--remote-sync", type=str, default=None, help="Optinoally sync with a remote path specified by this arg",)
    parser.add_argument("--remote-sync-frequency", type=int, default=300, help="How frequently to sync to a remote directly if --remote-sync is not None.",)
    parser.add_argument("--remote-sync-protocol", choices=["s3", "fsspec"], default="s3", help="How to do the remote sync backup if --remote-sync is not None.",)
    parser.add_argument("--delete-previous-checkpoint", default=False, action="store_true", help="If true, delete previous checkpoint after storing a new one.")
    args = parser.parse_args(args)

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.model)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
open_clip/src/training/precision.py
0 → 100644
View file @
f55a786e
import
torch
from
contextlib
import
suppress
def get_autocast(precision):
    """Map a precision string to a context-manager factory for autocast.

    'amp' -> torch.cuda.amp.autocast itself; the bf16 variants -> a factory
    producing a bfloat16 autocast context; anything else -> contextlib.suppress
    (a no-op context manager when called with no arguments).
    """
    if precision == 'amp':
        return torch.cuda.amp.autocast
    if precision in ('amp_bfloat16', 'amp_bf16'):
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    return suppress
open_clip/src/training/profile.py
0 → 100644
View file @
f55a786e
import
argparse
import
torch
import
open_clip
import
pandas
as
pd
from
fvcore.nn
import
FlopCountAnalysis
,
flop_count_str
,
ActivationCountAnalysis
# Command-line interface for the standalone profiler script.
parser = argparse.ArgumentParser(description='OpenCLIP Profiler')

# benchmark specific args
parser.add_argument('--model', metavar='NAME', default='', help='model(s) to profile')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME', help='Output csv file for results')
def profile_fvcore(
        model,
        image_input_size=(3, 224, 224),
        text_input_size=(77,),
        batch_size=1,
        detailed=False,
        force_cpu=False):
    """Count FLOPs and activations of the full CLIP model with fvcore,
    using dummy image and text batches. Returns (total_flops, total_acts)."""
    if force_cpu:
        model = model.to('cpu')
    reference_param = next(model.parameters())
    device, dtype = reference_param.device, reference_param.dtype
    image_batch = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
    # text input is token ids, hence int64 regardless of model dtype
    text_batch = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
    flop_counter = FlopCountAnalysis(model, (image_batch, text_batch))
    act_counter = ActivationCountAnalysis(model, (image_batch, text_batch))
    if detailed:
        print(flop_count_str(flop_counter))
    return flop_counter.total(), act_counter.total()
def profile_fvcore_text(
        model,
        text_input_size=(77,),
        batch_size=1,
        detailed=False,
        force_cpu=False):
    """Count FLOPs and activations of the text tower alone with fvcore.
    Returns (total_flops, total_acts)."""
    if force_cpu:
        model = model.to('cpu')
    device = next(model.parameters()).device
    # token-id input, hence int64
    dummy_tokens = torch.ones((batch_size,) + text_input_size, device=device, dtype=torch.int64)
    flop_counter = FlopCountAnalysis(model, dummy_tokens)
    act_counter = ActivationCountAnalysis(model, dummy_tokens)
    if detailed:
        print(flop_count_str(flop_counter))
    return flop_counter.total(), act_counter.total()
def profile_fvcore_image(
        model,
        image_input_size=(3, 224, 224),
        batch_size=1,
        detailed=False,
        force_cpu=False):
    """Count FLOPs and activations of the image tower alone with fvcore.
    Returns (total_flops, total_acts)."""
    if force_cpu:
        model = model.to('cpu')
    reference_param = next(model.parameters())
    device, dtype = reference_param.device, reference_param.dtype
    dummy_images = torch.ones((batch_size,) + image_input_size, device=device, dtype=dtype)
    flop_counter = FlopCountAnalysis(model, dummy_images)
    act_counter = ActivationCountAnalysis(model, dummy_images)
    if detailed:
        print(flop_count_str(flop_counter))
    return flop_counter.total(), act_counter.total()
def count_params(model):
    """Return the total number of parameter elements in *model*.

    *model* is any torch.nn.Module (or object exposing .parameters()).
    """
    # Generator expression avoids materializing an intermediate list
    # (sum([...]) was a minor idiom/perf smell — ruff C419).
    return sum(p.numel() for p in model.parameters())
def profile_model(model_name):
    """Create *model_name* via open_clip and profile it with fvcore.

    Returns a dict with model metadata (widths, embed_dim, image_size) plus
    GMACs / MActs / MParams for the full model and for the image and text
    towers individually. Profiling is attempted on the model's current
    device first (GPU when available) and falls back to CPU on RuntimeError
    (e.g. OOM or an op unsupported on GPU).
    """
    model = open_clip.create_model(model_name, force_custom_text=True, pretrained_hf=False)
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()

    if isinstance(model.visual.image_size, (tuple, list)):
        image_input_size = (3,) + tuple(model.visual.image_size[-2:])
    else:
        image_input_size = (3, model.visual.image_size, model.visual.image_size)
    text_input_size = (77,)

    results = {}
    results['model'] = model_name
    results['image_size'] = image_input_size[1]

    model_cfg = open_clip.get_model_config(model_name)
    if model_cfg:
        vision_cfg = open_clip.CLIPVisionCfg(**model_cfg['vision_cfg'])
        text_cfg = open_clip.CLIPTextCfg(**model_cfg['text_cfg'])
        results['image_width'] = int(vision_cfg.width)
        results['text_width'] = int(text_cfg.width)
        results['embed_dim'] = int(model_cfg['embed_dim'])
    else:
        # config unknown (e.g. custom model name) — fill placeholder zeros
        results['image_width'] = 0
        results['text_width'] = 0
        results['embed_dim'] = 0

    # Two attempts: the first runs on the model's current device; the last
    # forces CPU (force_cpu=not retries becomes True when retries hits 0).
    retries = 2
    while retries:
        retries -= 1
        try:
            macs, acts = profile_fvcore(
                model, image_input_size=image_input_size, text_input_size=text_input_size, force_cpu=not retries)

            image_macs, image_acts = profile_fvcore_image(
                model.visual, image_input_size=image_input_size, force_cpu=not retries)

            text_macs, text_acts = profile_fvcore_text(
                model.text, text_input_size=text_input_size, force_cpu=not retries)

            results['gmacs'] = round(macs / 1e9, 2)
            results['macts'] = round(acts / 1e6, 2)
            results['mparams'] = round(count_params(model) / 1e6, 2)
            results['image_gmacs'] = round(image_macs / 1e9, 2)
            results['image_macts'] = round(image_acts / 1e6, 2)
            results['image_mparams'] = round(count_params(model.visual) / 1e6, 2)
            results['text_gmacs'] = round(text_macs / 1e9, 2)
            results['text_macts'] = round(text_acts / 1e6, 2)
            results['text_mparams'] = round(count_params(model.text) / 1e6, 2)
            # BUGFIX: stop once profiling succeeds. Previously the loop fell
            # through and re-profiled on CPU, discarding the GPU results and
            # doubling the work.
            break
        except RuntimeError:
            # best-effort fallback: swallow the error and retry on CPU
            pass
    return results
def main():
    """Parse CLI args, profile each requested model, and print/save a results table."""
    args = parser.parse_args()

    # FIXME accept a text file name to allow lists of models in txt/csv
    if args.model == 'all':
        parsed_model = open_clip.list_models()
    else:
        parsed_model = args.model.split(',')

    # one results row per model name
    results = [profile_model(name) for name in parsed_model]

    df = pd.DataFrame(results, columns=results[0].keys())
    df = df.sort_values('gmacs')
    print(df)
    if args.results_file:
        df.to_csv(args.results_file, index=False)
# Run the profiler only when executed as a script (not on import).
if __name__ == '__main__':
    main()
open_clip/src/training/scheduler.py
0 → 100644
View file @
f55a786e
import
numpy
as
np
def assign_learning_rate(optimizer, new_lr):
    """Set the learning rate of every param group in *optimizer* to *new_lr*."""
    for group in optimizer.param_groups:
        group["lr"] = new_lr
def _warmup_lr(base_lr, warmup_length, step):
    """Linear warmup: ramp from base_lr/warmup_length at step 0 up to base_lr."""
    completed = step + 1
    return base_lr * completed / warmup_length
def const_lr(optimizer, base_lr, warmup_length, steps):
    """Constant-LR schedule with linear warmup.

    Returns a callable that, given the global step, assigns the LR to every
    param group and returns it: linear warmup for the first *warmup_length*
    steps, then a constant *base_lr*. *steps* is accepted for interface
    parity with the other schedulers but is unused here.
    """
    def _lr_adjuster(step):
        lr = _warmup_lr(base_lr, warmup_length, step) if step < warmup_length else base_lr
        assign_learning_rate(optimizer, lr)
        return lr
    return _lr_adjuster
def const_lr_cooldown(optimizer, base_lr, warmup_length, steps, cooldown_steps, cooldown_power=1.0, cooldown_end_lr=0.):
    """Constant LR with linear warmup and a final polynomial cooldown.

    The returned callable takes the global step, assigns the LR to every
    param group of *optimizer*, and returns it. During the final
    *cooldown_steps* of *steps* the LR decays from base_lr down to
    *cooldown_end_lr*; cooldown_power=1.0 is linear decay, other powers
    give polynomial decay.
    """
    def _lr_adjuster(step):
        start_cooldown_step = steps - cooldown_steps
        if step < warmup_length:
            lr = _warmup_lr(base_lr, warmup_length, step)
        elif step < start_cooldown_step:
            lr = base_lr
        else:
            e = step - start_cooldown_step
            es = steps - start_cooldown_step
            # linear decay if power == 1; polynomial decay otherwise;
            decay = (1 - (e / es)) ** cooldown_power
            lr = decay * (base_lr - cooldown_end_lr) + cooldown_end_lr
        assign_learning_rate(optimizer, lr)
        return lr
    return _lr_adjuster
def cosine_lr(optimizer, base_lr, warmup_length, steps):
    """Cosine-annealing LR schedule with linear warmup.

    The returned callable takes the global step, assigns the LR to every
    param group of *optimizer*, and returns it: linear warmup for
    *warmup_length* steps, then a half-cosine decay from base_lr to 0 over
    the remaining steps.
    """
    def _lr_adjuster(step):
        if step < warmup_length:
            lr = _warmup_lr(base_lr, warmup_length, step)
        else:
            completed = step - warmup_length
            total = steps - warmup_length
            lr = 0.5 * (1 + np.cos(np.pi * completed / total)) * base_lr
        assign_learning_rate(optimizer, lr)
        return lr
    return _lr_adjuster
open_clip/src/training/train.py
0 → 100644
View file @
f55a786e
import
json
import
logging
import
math
import
os
import
time
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
try
:
import
wandb
except
ImportError
:
wandb
=
None
from
open_clip
import
ClipLoss
,
get_cast_dtype
from
.distributed
import
is_master
from
.zero_shot
import
zero_shot_eval
from
.precision
import
get_autocast
class AverageMeter(object):
    """Tracks the most recent value and a running (count-weighted) average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value
        self.avg = 0    # running average = sum / count
        self.sum = 0    # weighted sum of observed values
        self.count = 0  # total observation count

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running average."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count
def unwrap_model(model):
    """Return the bare model, stripping a wrapper that exposes `.module`.

    Wrappers such as DistributedDataParallel hold the real network in a
    `.module` attribute; plain models are returned unchanged.
    """
    return getattr(model, 'module', model)
def backward(total_loss, scaler):
    """Backpropagate *total_loss*, scaling it first when an AMP GradScaler is given."""
    if scaler is None:
        total_loss.backward()
    else:
        scaler.scale(total_loss).backward()
def train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None):
    """Run one epoch of contrastive (CLIP) training with optional AMP and grad accumulation.

    Args:
        model: returns (image_features, text_features, logit_scale) for (images, texts).
        data: dict; data['train'] must expose .dataloader and .set_epoch().
        epoch: 0-based epoch index, used for the global step and logging.
        optimizer: stepped once per effective (accumulated) batch.
        scaler: optional AMP GradScaler; None disables loss scaling.
        scheduler: callable taking the global step (skipped when args.skip_scheduler).
        args: training config namespace (precision, accum_freq, grad_clip_norm, ...).
        tb_writer: optional TensorBoard SummaryWriter.
    """
    device = torch.device(args.device)
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    model.train()
    loss = ClipLoss(
        local_loss=args.local_loss,
        gather_with_grad=args.gather_with_grad,
        cache_labels=True,
        rank=args.rank,
        world_size=args.world_size,
        use_horovod=args.horovod)

    data['train'].set_epoch(epoch)  # set epoch in process safe manner via sampler or shared_epoch
    dataloader = data['train'].dataloader
    # accum_freq dataloader batches form one optimizer step
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    # digits needed to right-align the sample counter in log lines
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    if args.accum_freq > 1:
        # caches for gradient accumulation: raw inputs and no-grad features
        accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []

    loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        i_accum = i // args.accum_freq
        step = num_batches_per_epoch * epoch + i_accum

        if not args.skip_scheduler:
            scheduler(step)

        images, texts = batch
        images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
        texts = texts.to(device=device, non_blocking=True)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        if args.accum_freq == 1:
            # simple path: one forward/backward per optimizer step
            with autocast():
                image_features, text_features, logit_scale = model(images, texts)
                total_loss = loss(image_features, text_features, logit_scale)

            backward(total_loss, scaler)
        else:
            # First, cache the features without any gradient tracking.
            with torch.no_grad():
                with autocast():
                    chunk_image_features, chunk_text_features, _ = model(images, texts)
                accum_image_features.append(chunk_image_features)
                accum_text_features.append(chunk_text_features)

                accum_images.append(images)
                accum_texts.append(texts)

            # If (i + 1) % accum_freq is not zero, move on to the next batch.
            if ((i + 1) % args.accum_freq) > 0:
                # FIXME this makes data time logging unreliable when accumulating
                continue

            # Now, ready to take gradients for the last accum_freq batches.
            # Re-do the forward pass for those batches, and use the cached features from the other batches as negatives.
            # Call backwards each time, but only step optimizer at the end.
            optimizer.zero_grad()
            for j in range(args.accum_freq):
                images = accum_images[j]
                texts = accum_texts[j]
                with autocast():
                    chunk_image_features, chunk_text_features, logit_scale = model(images, texts)
                    image_features = torch.cat(accum_image_features[:j] + [chunk_image_features] + accum_image_features[j + 1:])
                    text_features = torch.cat(accum_text_features[:j] + [chunk_text_features] + accum_text_features[j + 1:])
                    total_loss = loss(image_features, text_features, logit_scale)
                backward(total_loss, scaler)

        if scaler is not None:
            if args.horovod:
                # horovod: synchronize grads before unscale/clip, then step
                # inside skip_synchronize() to avoid a second allreduce
                optimizer.synchronize()
                scaler.unscale_(optimizer)
                if args.grad_clip_norm is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                with optimizer.skip_synchronize():
                    scaler.step(optimizer)
            else:
                if args.grad_clip_norm is not None:
                    # unscale before clipping so the norm is computed on true grads
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
                scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()

        # reset gradient accum, if enabled
        if args.accum_freq > 1:
            accum_images, accum_texts, accum_image_features, accum_text_features = [], [], [], []

        # Note: we clamp to 4.6052 = ln(100), as in the original paper.
        with torch.no_grad():
            unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i_accum + 1
        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.accum_freq * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            loss_m.update(total_loss.item(), batch_size)
            logit_scale_scalar = logit_scale.item()
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {args.accum_freq * args.batch_size * args.world_size / batch_time_m.val:#g}/s "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f}"
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "loss": loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": args.accum_freq * args.batch_size * args.world_size / batch_time_m.val,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"]
            }
            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)
                if args.wandb:
                    assert wandb is not None, 'Please install wandb.'
                    wandb.log({name: val, 'step': step})

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()
    # end for
def evaluate(model, data, epoch, args, tb_writer=None):
    """Evaluate on the validation set and zero-shot ImageNet; master rank only.

    Returns a dict of metrics (empty on non-master ranks). The validation
    pass runs only when data contains 'val' and the epoch matches
    args.val_frequency (or is the final epoch). Metrics are written to
    TensorBoard / results.jsonl (when args.save_logs) and wandb (when
    args.wandb).
    """
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args)
    metrics.update(zero_shot_metrics)

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_image_features @ all_text_features will blow up memory and compute very quickly
        cumulative_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                images, texts = batch
                images = images.to(device=device, dtype=cast_dtype, non_blocking=True)
                texts = texts.to(device=device, non_blocking=True)

                with autocast():
                    image_features, text_features, logit_scale = model(images, texts)
                    # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
                    # however, system RAM is easily exceeded and compute time becomes problematic
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    # symmetric InfoNCE loss: diagonal entries are the matches
                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
                        f"Loss: {cumulative_loss / num_samples:.6f}\t")

            # retrieval metrics over the whole (CPU-accumulated) val set
            val_metrics = get_metrics(
                image_features=torch.cat(all_image_features),
                text_features=torch.cat(all_text_features),
                logit_scale=logit_scale.cpu(),
            )
            loss = cumulative_loss / num_samples
            metrics.update(
                {**val_metrics, "val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
            )

    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    if args.save_logs:
        for name, val in metrics.items():
            if tb_writer is not None:
                tb_writer.add_scalar(f"val/{name}", val, epoch)

        # append one JSON line per evaluation to results.jsonl
        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        for name, val in metrics.items():
            wandb.log({f"val/{name}": val, 'epoch': epoch})

    return metrics
def get_metrics(image_features, text_features, logit_scale):
    """Compute cross-modal retrieval metrics for paired image/text features.

    Row i of the similarity matrix corresponds to ground-truth index i.
    Returns mean rank, median rank, and R@{1,5,10} for both the
    image-to-text and text-to-image directions.
    """
    logits_per_image = (logit_scale * image_features @ text_features.t()).detach().cpu()
    logits_per_text = logits_per_image.t().detach().cpu()

    ground_truth = torch.arange(len(text_features)).view(-1, 1)

    metrics = {}
    for name, logit in (("image_to_text", logits_per_image), ("text_to_image", logits_per_text)):
        ranking = torch.argsort(logit, descending=True)
        # 0-based position of the correct match in each sorted row
        preds = torch.where(ranking == ground_truth)[1]
        preds = preds.detach().cpu().numpy()
        metrics[f"{name}_mean_rank"] = preds.mean() + 1
        metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1
        for k in [1, 5, 10]:
            metrics[f"{name}_R@{k}"] = np.mean(preds < k)
    return metrics
open_clip/src/training/zero_shot.py
0 → 100644
View file @
f55a786e
import
logging
import
torch
import
torch.nn.functional
as
F
from
tqdm
import
tqdm
from
open_clip
import
get_cast_dtype
,
get_tokenizer
from
.precision
import
get_autocast
from
.imagenet_zeroshot_data
import
imagenet_classnames
,
openai_imagenet_template
def zero_shot_classifier(model, classnames, templates, args):
    """Build a zero-shot classification head from class names and prompt templates.

    For each class, all templated prompts are encoded, L2-normalized,
    averaged, and re-normalized into a single class embedding. Returns the
    stacked weights with shape (embed_dim, num_classes) on args.device.
    """
    tokenizer = get_tokenizer(args.model)
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [template(classname) for template in templates]  # format with class
            texts = tokenizer(texts).to(args.device)  # tokenize
            # DDP wraps the model; reach through .module except under horovod
            if args.distributed and not args.horovod:
                class_embeddings = model.module.encode_text(texts)
            else:
                class_embeddings = model.encode_text(texts)
            # mean of normalized prompt embeddings, then renormalize
            class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device)
    return zeroshot_weights
def accuracy(output, target, topk=(1,)):
    """Count of correct top-k predictions for each k in *topk*.

    output: (batch, num_classes) scores; target: (batch,) class indices.
    Returns a list of raw correct counts (floats), not rates — callers
    divide by the sample count themselves.
    """
    maxk = max(topk)
    # indices of the top-maxk classes, transposed to one column per sample
    pred = output.topk(maxk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    results = []
    for k in topk:
        hits = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        results.append(float(hits.cpu().numpy()))
    return results
def run(model, classifier, dataloader, args):
    """Compute zero-shot top-1/top-5 accuracy of *model* over *dataloader*.

    *classifier* is the (embed_dim, num_classes) weight matrix built by
    zero_shot_classifier. Returns (top1, top5) as fractions of correct
    predictions.
    """
    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)
    with torch.no_grad():
        top1, top5, n = 0., 0., 0.
        for images, target in tqdm(dataloader, unit_scale=args.batch_size):
            images = images.to(args.device)
            if cast_dtype is not None:
                images = images.to(dtype=cast_dtype)
            target = target.to(args.device)

            with autocast():
                # predict
                # DDP wraps the model; reach through .module except under horovod
                if args.distributed and not args.horovod:
                    image_features = model.module.encode_image(images)
                else:
                    image_features = model.encode_image(images)
                image_features = F.normalize(image_features, dim=-1)
                # fixed temperature of 100 for zero-shot logits
                logits = 100. * image_features @ classifier

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

    top1 = (top1 / n)
    top5 = (top5 / n)
    return top1, top5
def zero_shot_eval(model, data, epoch, args):
    """Run zero-shot ImageNet evaluation, gated by args.zeroshot_frequency.

    Returns an empty dict when no ImageNet split is present, the frequency
    is 0, or the epoch is neither a multiple of the frequency nor the final
    epoch. Otherwise builds the zero-shot classifier once and evaluates the
    available splits.
    """
    has_imagenet = 'imagenet-val' in data or 'imagenet-v2' in data
    if not has_imagenet or args.zeroshot_frequency == 0:
        return {}
    if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs:
        return {}

    logging.info('Starting zero-shot imagenet.')

    logging.info('Building zero-shot classifier')
    classifier = zero_shot_classifier(model, imagenet_classnames, openai_imagenet_template, args)

    logging.info('Using classifier')
    results = {}
    for split, prefix in (('imagenet-val', 'imagenet'), ('imagenet-v2', 'imagenetv2')):
        if split in data:
            top1, top5 = run(model, classifier, data[split].dataloader, args)
            results[f'{prefix}-zeroshot-val-top1'] = top1
            results[f'{prefix}-zeroshot-val-top5'] = top5

    logging.info('Finished zero-shot imagenet.')
    return results
open_clip/tests/test_download_pretrained.py
0 → 100644
View file @
f55a786e
import
requests
import
torch
from
PIL
import
Image
import
hashlib
import
tempfile
import
unittest
from
io
import
BytesIO
from
pathlib
import
Path
from
unittest.mock
import
patch
from
urllib3
import
HTTPResponse
from
urllib3._collections
import
HTTPHeaderDict
import
open_clip
from
open_clip.pretrained
import
download_pretrained_from_url
class DownloadPretrainedTests(unittest.TestCase):
    """Tests for download_pretrained_from_url: caching behavior and sha256 validation."""

    def create_response(self, data, status_code=200, content_type='application/octet-stream'):
        """Build a non-preloaded urllib3 HTTPResponse serving *data*, for mocking urlopen."""
        fp = BytesIO(data)
        headers = HTTPHeaderDict({
            'Content-Type': content_type,
            'Content-Length': str(len(data))
        })
        raw = HTTPResponse(fp, preload_content=False, headers=headers, status=status_code)
        return raw

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic(self, urllib):
        """OpenAI-style URLs embed the full sha256 in the path; a matching download succeeds."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted(self, urllib):
        """A download whose sha256 differs from the URL's hash raises RuntimeError."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            # NOTE(review): 'does not not match' mirrors the actual (typo'd) message
            # raised in open_clip.pretrained — don't "fix" the regex independently.
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_valid_cache(self, urllib):
        """A cached file that passes the checksum is reused; no network request is made."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(file_contents)
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_not_called()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_openaipublic_corrupted_cache(self, urllib):
        """A cached file failing the checksum is ignored and re-downloaded."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            local_file = Path(root) / 'RN50.pt'
            local_file.write_bytes(b'corrupted pretrained model')
            url = f'https://openaipublic.azureedge.net/clip/models/{expected_hash}/RN50.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations(self, urllib):
        """mlfoundations URLs embed only the first 8 hex chars of the sha256 in the filename."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(file_contents)
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_url_from_mlfoundations_corrupted(self, urllib):
        """A corrupted download also fails the short (8-char) checksum check."""
        file_contents = b'pretrained model weights'
        expected_hash = hashlib.sha256(file_contents).hexdigest()[:8]
        urllib.request.urlopen.return_value = self.create_response(b'corrupted pretrained model')
        with tempfile.TemporaryDirectory() as root:
            url = f'https://github.com/mlfoundations/download/v0.2-weights/rn50-quickgelu-{expected_hash}.pt'
            with self.assertRaisesRegex(RuntimeError, r'checksum does not not match'):
                download_pretrained_from_url(url, root)
            urllib.request.urlopen.assert_called_once()

    @patch('open_clip.pretrained.urllib')
    def test_download_pretrained_from_hfh(self, urllib):
        """End-to-end: load a tiny HF-hub model and check CLIP probabilities on a sample image.

        Downloads an image over the network and compares softmax probabilities
        against reference values for the pinned tiny test model.
        """
        model, _, preprocess = open_clip.create_model_and_transforms('hf-hub:hf-internal-testing/tiny-open-clip-model')
        tokenizer = open_clip.get_tokenizer('hf-hub:hf-internal-testing/tiny-open-clip-model')
        img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
        image = preprocess(Image.open(requests.get(img_url, stream=True).raw)).unsqueeze(0)
        text = tokenizer(["a diagram", "a dog", "a cat"])

        with torch.no_grad():
            image_features = model.encode_image(image)
            text_features = model.encode_text(text)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)

            text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)

        self.assertTrue(torch.allclose(text_probs, torch.tensor([[0.0597, 0.6349, 0.3053]]), 1e-3))
open_clip/tests/test_hf_model.py
0 → 100644
View file @
f55a786e
import
pytest
import
torch
from
open_clip.hf_model
import
_POOLERS
,
HFTextEncoder
from
transformers
import
AutoConfig
from
transformers.modeling_outputs
import
BaseModelOutput
# test poolers
def test_poolers():
    """Every registered pooler maps (batch, seq, dim) hidden states to (batch, dim)."""
    bs, sl, d = 2, 10, 5
    # deterministic hidden states: position index scaled per feature dimension
    h = torch.arange(sl).repeat(bs).reshape(bs, sl)[..., None] * torch.linspace(0.2, 1., d)
    # mask out the tail of each sequence to exercise attention-mask handling
    mask = torch.ones(bs, sl, dtype=torch.long)
    mask[:2, 6:] = 0
    x = BaseModelOutput(h)
    for name, cls in _POOLERS.items():
        pooler = cls()
        res = pooler(x, mask)
        assert res.shape == (bs, d), f"{name} returned wrong shape"
# test HFTextEncoder
@pytest.mark.parametrize("model_id", ["arampacha/roberta-tiny", "roberta-base", "xlm-roberta-base", "google/mt5-base"])
def test_pretrained_text_encoder(model_id):
    """HFTextEncoder wraps a pretrained HF model and projects embeddings to dim *d*."""
    bs, sl, d = 2, 10, 64
    cfg = AutoConfig.from_pretrained(model_id)
    model = HFTextEncoder(model_id, d, proj='linear')
    # random token ids within the model's vocabulary
    x = torch.randint(0, cfg.vocab_size, (bs, sl))
    with torch.no_grad():
        emb = model(x)

    assert emb.shape == (bs, d)
open_clip/tests/test_inference.py
0 → 100644
View file @
f55a786e
import
os
import
pytest
import
torch
import
open_clip
import
util_test
# Force CPU-only execution so CI inference runs are deterministic.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

models_to_test = set(open_clip.list_models())

# testing exemptions
models_to_test = models_to_test.difference({
    # not available with timm yet
    # see https://github.com/mlfoundations/open_clip/issues/219
    'convnext_xlarge',
    'convnext_xxlarge',
    'convnext_xxlarge_320',
    'vit_medium_patch16_gap_256',
    # exceeds GH runner memory limit
    'ViT-bigG-14',
    'ViT-e-14',
    'mt5-xl-ViT-H-14',
})

# Optionally restrict to an externally supplied list of model names (one per line).
if 'OPEN_CLIP_TEST_REG_MODELS' in os.environ:
    external_model_list = os.environ['OPEN_CLIP_TEST_REG_MODELS']
    with open(external_model_list, 'r') as f:
        models_to_test = set(f.read().splitlines()).intersection(models_to_test)
    print(f"Selected models from {external_model_list}: {models_to_test}")

# Sorted list so pytest parametrization order is stable across runs.
models_to_test = list(models_to_test)
models_to_test.sort()
@pytest.mark.regression_test
@pytest.mark.parametrize('model_name', models_to_test)
def test_inference_with_data(
    model_name,
    pretrained=None,
    pretrained_hf=False,
    precision='fp32',
    jit=False,
    force_quick_gelu=False,
):
    """Regression test: model outputs on fixed random inputs must match stored ground truth.

    Skips (rather than fails) when the pre-generated input/ground-truth
    tensors for this model are not present on disk.
    """
    util_test.seed_all()
    model, _, preprocess_val = open_clip.create_model_and_transforms(
        model_name,
        pretrained=pretrained,
        precision=precision,
        jit=jit,
        force_quick_gelu=force_quick_gelu,
        pretrained_hf=pretrained_hf
    )
    # identifies the stored ground-truth files for this model/precision combo
    model_id = f'{model_name}_{pretrained or pretrained_hf}_{precision}'
    input_dir, output_dir = util_test.get_data_dirs()
    # text
    input_text_path = os.path.join(input_dir, 'random_text.pt')
    gt_text_path = os.path.join(output_dir, f'{model_id}_random_text.pt')
    if not os.path.isfile(input_text_path):
        pytest.skip(reason=f"missing test data, expected at {input_text_path}")
    if not os.path.isfile(gt_text_path):
        pytest.skip(reason=f"missing test data, expected at {gt_text_path}")
    input_text = torch.load(input_text_path)
    gt_text = torch.load(gt_text_path)
    y_text = util_test.inference_text(model, model_name, input_text)
    assert (y_text == gt_text).all(), f"text output differs @ {input_text_path}"
    # image
    image_size = model.visual.image_size
    if not isinstance(image_size, tuple):
        image_size = (image_size, image_size)
    # image fixtures are keyed by resolution; ground truth by model id
    input_image_path = os.path.join(input_dir, f'random_image_{image_size[0]}_{image_size[1]}.pt')
    gt_image_path = os.path.join(output_dir, f'{model_id}_random_image.pt')
    if not os.path.isfile(input_image_path):
        pytest.skip(reason=f"missing test data, expected at {input_image_path}")
    if not os.path.isfile(gt_image_path):
        pytest.skip(reason=f"missing test data, expected at {gt_image_path}")
    input_image = torch.load(input_image_path)
    gt_image = torch.load(gt_image_path)
    y_image = util_test.inference_image(model, preprocess_val, input_image)
    assert (y_image == gt_image).all(), f"image output differs @ {input_image_path}"
Prev
1
…
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment