chenych / MML_pytorch — commit c501623c
Authored Dec 21, 2023 by chenych
add vlmo
Parent: 4538607b
Changes: 57. Showing 20 changed files with 2169 additions and 0 deletions (+2169 / -0).
Changed files:
  vmlo/vlmo/datamodules/vg_caption_datamodule.py (+15 / -0)
  vmlo/vlmo/datamodules/vqav2_datamodule.py (+36 / -0)
  vmlo/vlmo/datamodules/wikibk_datamodule.py (+15 / -0)
  vmlo/vlmo/datasets/__init__.py (+8 / -0)
  vmlo/vlmo/datasets/base_dataset.py (+227 / -0)
  vmlo/vlmo/datasets/coco_caption_karpathy_dataset.py (+27 / -0)
  vmlo/vlmo/datasets/conceptual_caption_dataset.py (+19 / -0)
  vmlo/vlmo/datasets/f30k_caption_karpathy_dataset.py (+18 / -0)
  vmlo/vlmo/datasets/nlvr2_dataset.py (+51 / -0)
  vmlo/vlmo/datasets/sbu_caption_dataset.py (+19 / -0)
  vmlo/vlmo/datasets/vg_caption_dataset.py (+18 / -0)
  vmlo/vlmo/datasets/vqav2_dataset.py (+47 / -0)
  vmlo/vlmo/datasets/wikibk_dataset.py (+19 / -0)
  vmlo/vlmo/gadgets/__init__.py (+0 / -0)
  vmlo/vlmo/gadgets/my_metrics.py (+69 / -0)
  vmlo/vlmo/modules/__init__.py (+1 / -0)
  vmlo/vlmo/modules/dist_utils.py (+270 / -0)
  vmlo/vlmo/modules/heads.py (+52 / -0)
  vmlo/vlmo/modules/multiway_transformer.py (+407 / -0)
  vmlo/vlmo/modules/objectives.py (+851 / -0)
vmlo/vlmo/datamodules/vg_caption_datamodule.py (new file, mode 100644)

from vlmo.datasets import VisualGenomeCaptionDataset
from .datamodule_base import BaseDataModule


class VisualGenomeCaptionDataModule(BaseDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        return VisualGenomeCaptionDataset

    @property
    def dataset_name(self):
        return "vg"
vmlo/vlmo/datamodules/vqav2_datamodule.py (new file, mode 100644)

from vlmo.datasets import VQAv2Dataset
from .datamodule_base import BaseDataModule
from collections import defaultdict


class VQAv2DataModule(BaseDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        return VQAv2Dataset

    @property
    def dataset_name(self):
        return "vqa"

    def setup(self, stage):
        super().setup(stage)

        train_answers = self.train_dataset.table["answers"].to_pandas().tolist()
        val_answers = self.val_dataset.table["answers"].to_pandas().tolist()
        train_labels = self.train_dataset.table["answer_labels"].to_pandas().tolist()
        val_labels = self.val_dataset.table["answer_labels"].to_pandas().tolist()

        all_answers = [c for c in train_answers + val_answers if c is not None]
        all_answers = [l for lll in all_answers for ll in lll for l in ll]
        all_labels = [c for c in train_labels + val_labels if c is not None]
        all_labels = [l for lll in all_labels for ll in lll for l in ll]

        self.answer2id = {k: v for k, v in zip(all_answers, all_labels)}
        sorted_a2i = sorted(self.answer2id.items(), key=lambda x: x[1])
        self.num_class = max(self.answer2id.values()) + 1

        self.id2answer = defaultdict(lambda: "unknown")
        for k, v in sorted_a2i:
            self.id2answer[v] = k
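As a quick illustration of what setup() produces: after the triple-nested flattening, it zips answer strings with their label ids into answer2id and inverts that into a defaultdict id2answer, so unseen class ids fall back to "unknown". The sketch below reproduces that mapping on made-up answer/label pairs (the real values come from the VQAv2 arrow tables, so treat these as placeholders):

from collections import defaultdict

# Hypothetical flattened (answer, label) pairs, standing in for the arrow-table contents.
all_answers = ["yes", "no", "2", "yes"]
all_labels = [3, 9, 0, 3]

answer2id = {k: v for k, v in zip(all_answers, all_labels)}   # {"yes": 3, "no": 9, "2": 0}
num_class = max(answer2id.values()) + 1                       # 10 output classes

id2answer = defaultdict(lambda: "unknown")
for k, v in sorted(answer2id.items(), key=lambda x: x[1]):
    id2answer[v] = k                                          # ids with no answer map to "unknown"

print(id2answer[3], id2answer[7])                             # yes unknown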
vmlo/vlmo/datamodules/wikibk_datamodule.py (new file, mode 100644)

from vlmo.datasets import WikibkDataset
from .datamodule_base import BaseDataModule


class WikibkDataModule(BaseDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def dataset_cls(self):
        return WikibkDataset

    @property
    def dataset_name(self):
        return "wikibk"
vmlo/vlmo/datasets/__init__.py (new file, mode 100644)

from .vg_caption_dataset import VisualGenomeCaptionDataset
from .coco_caption_karpathy_dataset import CocoCaptionKarpathyDataset
from .f30k_caption_karpathy_dataset import F30KCaptionKarpathyDataset
from .conceptual_caption_dataset import ConceptualCaptionDataset
from .sbu_caption_dataset import SBUCaptionDataset
from .wikibk_dataset import WikibkDataset
from .vqav2_dataset import VQAv2Dataset
from .nlvr2_dataset import NLVR2Dataset
vmlo/vlmo/datasets/base_dataset.py (new file, mode 100644)

import random
import torch
import io
import pyarrow as pa
import os
from PIL import Image

from vlmo.transforms import keys_to_transforms


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_dir: str,
        transform_keys: list,
        image_size: int,
        names: list,
        text_column_name: str = "",
        remove_duplicate=False,
        max_text_len=40,
        draw_false_image=0,
        draw_false_text=0,
        image_only=False,
    ):
        """
        data_dir : where dataset file *.arrow lives; existence should be guaranteed via DataModule.prepare_data
        transform_keys : keys for generating augmented views of images
        text_column_name : pyarrow table column name that has list of strings as elements
        """
        assert len(transform_keys) >= 1
        super().__init__()

        self.transforms = keys_to_transforms(transform_keys, size=image_size)
        self.text_column_name = text_column_name
        self.names = names
        self.max_text_len = max_text_len
        self.draw_false_image = draw_false_image
        self.draw_false_text = draw_false_text
        self.image_only = image_only
        self.data_dir = data_dir

        if len(names) != 0:
            tables = [
                pa.ipc.RecordBatchFileReader(
                    pa.memory_map(f"{data_dir}/{name}.arrow", "r")
                ).read_all()
                for name in names
                if os.path.isfile(f"{data_dir}/{name}.arrow")
            ]

            self.table_names = list()
            for i, name in enumerate(names):
                self.table_names += [name] * len(tables[i])

            self.table = pa.concat_tables(tables, promote=True)
            if text_column_name != "":
                self.text_column_name = text_column_name
                self.all_texts = self.table[text_column_name].to_pandas().tolist()
                self.all_texts = (
                    [list(set(texts)) for texts in self.all_texts]
                    if remove_duplicate
                    else self.all_texts
                )
            else:
                self.all_texts = list()
        else:
            self.all_texts = list()

        self.index_mapper = dict()

        if text_column_name != "" and not self.image_only:
            j = 0
            for i, texts in enumerate(self.all_texts):
                for _j in range(len(texts)):
                    self.index_mapper[j] = (i, _j)
                    j += 1
        else:
            for i in range(len(self.table)):
                self.index_mapper[i] = (i, None)

    @property
    def corpus(self):
        return [text for texts in self.all_texts for text in texts]

    def __len__(self):
        return len(self.index_mapper)

    def get_raw_image(self, index, image_key="image"):
        index, caption_index = self.index_mapper[index]
        image_bytes = io.BytesIO(self.table[image_key][index].as_py())
        image_bytes.seek(0)
        return Image.open(image_bytes).convert("RGB")

    def get_image(self, index, image_key="image"):
        image = self.get_raw_image(index, image_key=image_key)
        image_tensor = [tr(image) for tr in self.transforms]
        return {
            "image": image_tensor,
            "img_index": self.index_mapper[index][0],
            "cap_index": self.index_mapper[index][1],
            "raw_index": index,
        }

    def get_false_image(self, rep, image_key="image"):
        random_index = random.randint(0, len(self.index_mapper) - 1)
        image = self.get_raw_image(random_index, image_key=image_key)
        image_tensor = [tr(image) for tr in self.transforms]
        return {f"false_image_{rep}": image_tensor}

    def get_text(self, raw_index):
        index, caption_index = self.index_mapper[raw_index]
        text = self.all_texts[index][caption_index]
        encoding = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_special_tokens_mask=True,
        )
        return {
            "text": (text, encoding),
            "img_index": index,
            "cap_index": caption_index,
            "raw_index": raw_index,
        }

    def get_false_text(self, rep):
        random_index = random.randint(0, len(self.index_mapper) - 1)
        index, caption_index = self.index_mapper[random_index]
        text = self.all_texts[index][caption_index]
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_text_len,
            return_special_tokens_mask=True,
        )
        return {f"false_text_{rep}": (text, encoding)}

    def get_suite(self, index):
        result = None
        while result is None:
            try:
                ret = dict()
                ret.update(self.get_image(index))
                if not self.image_only:
                    txt = self.get_text(index)
                    ret.update({"replica": True if txt["cap_index"] > 0 else False})
                    ret.update(txt)

                for i in range(self.draw_false_image):
                    ret.update(self.get_false_image(i))
                for i in range(self.draw_false_text):
                    ret.update(self.get_false_text(i))
                result = True
            except Exception as e:
                print(f"Error while read file idx {index} in {self.names[0]} -> {e}")
                index = random.randint(0, len(self.index_mapper) - 1)
        return ret

    def get_text_suite(self, index):
        result = None
        while result is None:
            try:
                ret = dict()
                txt = self.get_text(index)
                ret.update({"replica": True if txt["cap_index"] > 0 else False})
                ret.update(txt)
                result = True
            except Exception as e:
                print(f"Error while read file idx {index} in {self.names[0]} -> {e}")
                index = random.randint(0, len(self.index_mapper) - 1)
        return ret

    def collate(self, batch, mlm_collator):
        batch_size = len(batch)
        keys = set([key for b in batch for key in b.keys()])
        dict_batch = {k: [dic[k] if k in dic else None for dic in batch] for k in keys}

        img_keys = [k for k in list(dict_batch.keys()) if "image" in k]

        for img_key in img_keys:
            new_imgs = [tmp_img[0] for tmp_img in dict_batch[img_key]]
            batch_new_imgs = torch.stack(new_imgs, dim=0)
            dict_batch[img_key] = [batch_new_imgs]

        txt_keys = [k for k in list(dict_batch.keys()) if "text" in k]

        if len(txt_keys) != 0:
            texts = [[d[0] for d in dict_batch[txt_key]] for txt_key in txt_keys]
            encodings = [[d[1] for d in dict_batch[txt_key]] for txt_key in txt_keys]
            draw_text_len = len(encodings)
            flatten_encodings = [e for encoding in encodings for e in encoding]
            flatten_mlms = mlm_collator(flatten_encodings)

            for i, txt_key in enumerate(txt_keys):
                texts, encodings = (
                    [d[0] for d in dict_batch[txt_key]],
                    [d[1] for d in dict_batch[txt_key]],
                )

                mlm_ids, mlm_labels = (
                    flatten_mlms["input_ids"][batch_size * (i) : batch_size * (i + 1)],
                    flatten_mlms["labels"][batch_size * (i) : batch_size * (i + 1)],
                )

                input_ids = torch.zeros_like(mlm_ids)
                attention_mask = torch.zeros_like(mlm_ids)
                for _i, encoding in enumerate(encodings):
                    _input_ids, _attention_mask = (
                        torch.tensor(encoding["input_ids"]),
                        torch.tensor(encoding["attention_mask"]),
                    )
                    input_ids[_i, : len(_input_ids)] = _input_ids
                    attention_mask[_i, : len(_attention_mask)] = _attention_mask

                dict_batch[txt_key] = texts
                dict_batch[f"{txt_key}_ids"] = input_ids
                dict_batch[f"{txt_key}_labels"] = torch.full_like(input_ids, -100)
                dict_batch[f"{txt_key}_ids_mlm"] = mlm_ids
                dict_batch[f"{txt_key}_labels_mlm"] = mlm_labels
                dict_batch[f"{txt_key}_masks"] = attention_mask

        return dict_batch
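To make the flow above concrete, here is a minimal sketch of how a dataset built on BaseDataset might be instantiated and collated. The data directory and transform key are placeholders, and the tokenizer / MLM collator are normally attached by the datamodule (datamodule_base.py, not shown in this commit page), so this is illustrative wiring rather than the project's exact setup:

from transformers import BertTokenizer, DataCollatorForLanguageModeling
from vlmo.datasets import CocoCaptionKarpathyDataset

# Placeholder path and transform key; the real values come from the experiment config.
dataset = CocoCaptionKarpathyDataset(
    data_dir="/path/to/arrows",        # directory holding the *.arrow files
    transform_keys=["pixelbert"],      # one augmented view per key (assumed key name)
    image_size=224,
    split="val",
)

# The datamodule normally injects these; set them here only so collate() is runnable.
dataset.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
mlm_collator = DataCollatorForLanguageModeling(tokenizer=dataset.tokenizer, mlm=True)

sample = dataset[0]                    # {"image": [...], "text": (str, encoding), ...}
batch = dataset.collate([dataset[i] for i in range(4)], mlm_collator)
print(batch["text_ids"].shape, batch["image"][0].shape)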
vmlo/vlmo/datasets/coco_caption_karpathy_dataset.py (new file, mode 100644)

from .base_dataset import BaseDataset


class CocoCaptionKarpathyDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        self.split = split

        if split == "train":
            names = ["coco_caption_karpathy_train", "coco_caption_karpathy_restval"]
        elif split == "val":
            names = ["coco_caption_karpathy_val"]
        elif split == "test":
            names = ["coco_caption_karpathy_test"]

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        suite = self.get_suite(index)

        if "test" in self.split:
            _index, _question_index = self.index_mapper[index]
            iid = self.table["image_id"][_index].as_py()
            iid = int(iid.split(".")[0].split("_")[-1])
            suite.update({"iid": iid})

        return suite
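The iid parsing above strips a COCO filename down to its numeric image id: drop the extension, then keep the trailing underscore-separated field. Assuming an image_id in the usual COCO naming convention (the filename below is a hypothetical example, not taken from this commit), it works out like this:

iid = "COCO_val2014_000000391895.jpg"
iid = int(iid.split(".")[0].split("_")[-1])   # drop ".jpg", keep the trailing number
print(iid)                                    # 391895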
vmlo/vlmo/datasets/conceptual_caption_dataset.py (new file, mode 100644)

from glob import glob
from .base_dataset import BaseDataset


class ConceptualCaptionDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        if split == "test":
            split = "val"

        if split == "train":
            names = [f"conceptual_caption_train_{i}" for i in range(30)]
        elif split == "val":
            names = ["conceptual_caption_val_0"]

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        return self.get_suite(index)
vmlo/vlmo/datasets/f30k_caption_karpathy_dataset.py (new file, mode 100644)

from .base_dataset import BaseDataset


class F30KCaptionKarpathyDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]

        if split == "train":
            names = ["f30k_caption_karpathy_train"]
        elif split == "val":
            names = ["f30k_caption_karpathy_val"]
        elif split == "test":
            names = ["f30k_caption_karpathy_test"]

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        return self.get_suite(index)
vmlo/vlmo/datasets/nlvr2_dataset.py (new file, mode 100644)

from .base_dataset import BaseDataset
import sys
import random


class NLVR2Dataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        self.split = split

        if split == "train":
            names = ["nlvr2_train"]
        elif split == "val":
            names = ["nlvr2_dev", "nlvr2_test1"]
        elif split == "test":
            names = ["nlvr2_dev", "nlvr2_test1"]

        super().__init__(
            *args,
            **kwargs,
            names=names,
            text_column_name="questions",
            remove_duplicate=False,
        )

    def __getitem__(self, index):
        result = None
        while result is None:
            try:
                image_tensor_0 = self.get_image(index, image_key="image_0")["image"]
                image_tensor_1 = self.get_image(index, image_key="image_1")["image"]
                text = self.get_text(index)["text"]
                result = True
            except:
                print(
                    f"error while read file idx {index} in {self.names[0]}",
                    file=sys.stderr,
                )
                index = random.randint(0, len(self.index_mapper) - 1)

        index, question_index = self.index_mapper[index]
        answers = self.table["answers"][index][question_index].as_py()
        answers = answers == "True"

        return {
            "image_0": image_tensor_0,
            "image_1": image_tensor_1,
            "text": text,
            "answers": answers,
            "table_name": self.table_names[index],
        }
vmlo/vlmo/datasets/sbu_caption_dataset.py (new file, mode 100644)

from glob import glob
from .base_dataset import BaseDataset


class SBUCaptionDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        if split == "test":
            split = "val"

        if split == "train":
            names = [f"sbu_{i}" for i in range(9)]
        elif split == "val":
            names = []

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        return self.get_suite(index)
vmlo/vlmo/datasets/vg_caption_dataset.py (new file, mode 100644)

from .base_dataset import BaseDataset


class VisualGenomeCaptionDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        if split == "test":
            split = "val"

        if split == "train":
            names = ["vg"]
        elif split == "val":
            names = []

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        return self.get_suite(index)
vmlo/vlmo/datasets/vqav2_dataset.py (new file, mode 100644)

from .base_dataset import BaseDataset


class VQAv2Dataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        self.split = split

        if split == "train":
            names = ["vqav2_train", "vqav2_trainable_val"]
        elif split == "val":
            names = ["vqav2_rest_val"]
        elif split == "test":
            names = ["vqav2_test"]  # vqav2_test-dev for test-dev

        super().__init__(
            *args,
            **kwargs,
            names=names,
            text_column_name="questions",
            remove_duplicate=False,
        )

    def __getitem__(self, index):
        image_tensor = self.get_image(index)["image"]
        text = self.get_text(index)["text"]

        index, question_index = self.index_mapper[index]
        qid = self.table["question_id"][index][question_index].as_py()

        if self.split != "test":
            answers = self.table["answers"][index][question_index].as_py()
            labels = self.table["answer_labels"][index][question_index].as_py()
            scores = self.table["answer_scores"][index][question_index].as_py()
        else:
            answers = list()
            labels = list()
            scores = list()

        return {
            "image": image_tensor,
            "text": text,
            "vqa_answer": answers,
            "vqa_labels": labels,
            "vqa_scores": scores,
            "qid": qid,
        }
vmlo/vlmo/datasets/wikibk_dataset.py (new file, mode 100644)

from glob import glob
from .base_dataset import BaseDataset


class WikibkDataset(BaseDataset):
    def __init__(self, *args, split="", **kwargs):
        assert split in ["train", "val", "test"]
        if split == "test":
            split = "val"

        if split == "train":
            names = [f"wikibk_train_{i}" for i in range(50)]
        elif split == "val":
            names = ["wikibk_val_0"]

        super().__init__(*args, **kwargs, names=names, text_column_name="caption")

    def __getitem__(self, index):
        return self.get_text_suite(index)
vmlo/vlmo/gadgets/__init__.py (new file, mode 100644, empty)
vmlo/vlmo/gadgets/my_metrics.py (new file, mode 100644)

import torch
from torchmetrics import Metric


class Accuracy(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits, target):
        logits, target = (
            logits.detach().to(self.correct.device),
            target.detach().to(self.correct.device),
        )
        preds = logits.argmax(dim=-1)
        preds = preds[target != -100]
        target = target[target != -100]
        if target.numel() == 0:
            return 1

        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct / self.total


class Scalar(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, scalar):
        if isinstance(scalar, torch.Tensor):
            scalar = scalar.detach().to(self.scalar.device)
        else:
            scalar = torch.tensor(scalar).float().to(self.scalar.device)
        self.scalar += scalar
        self.total += 1

    def compute(self):
        return self.scalar / self.total


class VQAScore(Metric):
    def __init__(self, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits, target):
        logits, target = (
            logits.detach().float().to(self.score.device),
            target.detach().float().to(self.score.device),
        )
        logits = torch.max(logits, 1)[1]
        one_hots = torch.zeros(*target.size()).to(target)
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = one_hots * target

        self.score += scores.sum()
        self.total += len(logits)

    def compute(self):
        return self.score / self.total
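A minimal sketch of how these metrics behave on toy inputs (the tensors below are made-up values for illustration; in the project the metrics are updated inside the Lightning module and reset at epoch end):

import torch

acc = Accuracy()
logits = torch.tensor([[0.1, 2.0], [1.5, 0.2], [0.0, 3.0]])
target = torch.tensor([1, 0, -100])            # entries equal to -100 are ignored
acc.update(logits, target)
print(acc.compute())                           # tensor(1.) — both kept predictions are correct

vqa = VQAScore()
vqa_logits = torch.tensor([[0.2, 5.0, 0.1]])   # model picks answer index 1
vqa_target = torch.tensor([[0.0, 0.6, 0.3]])   # soft VQA answer scores
vqa.update(vqa_logits, vqa_target)
print(vqa.compute())                           # tensor(0.6000)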
vmlo/vlmo/modules/__init__.py (new file, mode 100644)

from .vlmo_module import VLMo
vmlo/vlmo/modules/dist_utils.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist

_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""


def get_world_size() -> int:
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def get_rank() -> int:
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def get_local_rank() -> int:
    """
    Returns:
        The rank of the current process within the local (per-machine) process group.
    """
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    assert _LOCAL_PROCESS_GROUP is not None
    return dist.get_rank(group=_LOCAL_PROCESS_GROUP)


def get_local_size() -> int:
    """
    Returns:
        The size of the per-machine process group,
        i.e. the number of processes per machine.
    """
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)


def is_main_process() -> bool:
    return get_rank() == 0


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")
    else:
        return dist.group.WORLD


def _serialize_to_tensor(data, group):
    backend = dist.get_backend(group)
    assert backend in ["gloo", "nccl"]
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024 ** 3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
                get_rank(), len(buffer) / (1024 ** 3), device
            )
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor


def _pad_to_largest_tensor(tensor, group):
    """
    Returns:
        list[int]: size of the tensor, on each rank
        Tensor: padded tensor that has the max size
    """
    world_size = dist.get_world_size(group=group)
    assert (
        world_size >= 1
    ), "comm.gather/all_gather must be called from ranks within the given group!"
    local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
    size_list = [
        torch.zeros([1], dtype=torch.int64, device=tensor.device)
        for _ in range(world_size)
    ]
    dist.all_gather(size_list, local_size, group=group)
    size_list = [int(size.item()) for size in size_list]

    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor


def all_gather(data, group=None):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: list of data gathered from each rank
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group) == 1:
        return [data]

    tensor = _serialize_to_tensor(data, group)

    size_list, tensor = _pad_to_largest_tensor(tensor, group)
    max_size = max(size_list)

    # receiving Tensor from all ranks
    tensor_list = [
        torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
        for _ in size_list
    ]
    dist.all_gather(tensor_list, tensor, group=group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def gather(data, dst=0, group=None):
    """
    Run gather on arbitrary picklable data (not necessarily tensors).
    Args:
        data: any picklable object
        dst (int): destination rank
        group: a torch process group. By default, will use a group which
            contains all ranks on gloo backend.
    Returns:
        list[data]: on dst, a list of data gathered from each rank. Otherwise,
            an empty list.
    """
    if get_world_size() == 1:
        return [data]
    if group is None:
        group = _get_global_gloo_group()
    if dist.get_world_size(group=group) == 1:
        return [data]
    rank = dist.get_rank(group=group)

    tensor = _serialize_to_tensor(data, group)
    size_list, tensor = _pad_to_largest_tensor(tensor, group)

    # receiving Tensor from all ranks
    if rank == dst:
        max_size = max(size_list)
        tensor_list = [
            torch.empty((max_size,), dtype=torch.uint8, device=tensor.device)
            for _ in size_list
        ]
        dist.gather(tensor, tensor_list, dst=dst, group=group)

        data_list = []
        for size, tensor in zip(size_list, tensor_list):
            buffer = tensor.cpu().numpy().tobytes()[:size]
            data_list.append(pickle.loads(buffer))
        return data_list
    else:
        dist.gather(tensor, [], dst=dst, group=group)
        return []


def shared_random_seed():
    """
    Returns:
        int: a random number that is the same across all workers.
            If workers need a shared RNG, they can use this shared seed to
            create one.
    All workers must call this function, otherwise it will deadlock.
    """
    ints = np.random.randint(2 ** 31)
    all_ints = all_gather(ints)
    return all_ints[0]


def reduce_dict(input_dict, average=True):
    """
    Reduce the values in the dictionary from all processes so that process with rank
    0 has the reduced results.
    Args:
        input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
        average (bool): whether to do average or sum
    Returns:
        a dict with the same keys as input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.reduce(values, dst=0)
        if dist.get_rank() == 0 and average:
            # only main process gets accumulated, so only divide by
            # world_size in this case
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict
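For context, a typical use of these helpers is collecting per-rank Python objects (for example, lists of VQA predictions) onto every process before writing results to disk. A minimal sketch, assuming torch.distributed has already been initialised by the launcher and the payload below is a hypothetical per-rank result list:

# Gather arbitrary picklable per-rank results onto every rank.
# Assumes dist.init_process_group(...) has already been called by the launcher.
local_results = [{"qid": 1, "answer": "yes"}]            # hypothetical per-rank payload
gathered = all_gather(local_results)                     # one entry per rank
if is_main_process():
    merged = [item for rank_list in gathered for item in rank_list]
    print(f"collected {len(merged)} predictions across {get_world_size()} ranks")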
vmlo/vlmo/modules/heads.py (new file, mode 100644)

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers.models.bert.modeling_bert import BertPredictionHeadTransform


class Pooler(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class ITMHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)

    def forward(self, x):
        x = self.fc(x)
        return x


class ITCHead(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.fc = nn.Linear(hidden_size, hidden_size, bias=False)

    def forward(self, x):
        x = self.fc(x)
        return x


class MLMHead(nn.Module):
    def __init__(self, config, weight=None):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
        if weight is not None:
            self.decoder.weight = weight

    def forward(self, x):
        x = self.transform(x)
        x = self.decoder(x) + self.bias
        return x
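As a rough illustration of how these heads sit on top of the encoder output (shapes below are illustrative; the actual wiring lives in vlmo_module.py, which is not shown on this page of the diff):

import torch

hidden_size = 768
seq_output = torch.randn(2, 40, hidden_size)   # [batch, seq_len, hidden] from the transformer

pooler = Pooler(hidden_size)
itm_head = ITMHead(hidden_size)

cls_feats = pooler(seq_output)                 # [2, 768] — tanh-projected first ([CLS]) token
itm_logits = itm_head(cls_feats)               # [2, 2]   — image-text matched vs. not matched
print(cls_feats.shape, itm_logits.shape)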
vmlo/vlmo/modules/multiway_transformer.py (new file, mode 100644)

""" Vision Transformer (ViT) in PyTorch
A PyTorch implement of Vision Transformers as described in
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
The official jax code is released and available at https://github.com/google-research/vision_transformer
Acknowledgments:
* The paper authors for releasing code and weights, thanks!
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
for some einops/einsum fun
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
DeiT model defs and weights from https://github.com/facebookresearch/deit,
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from pytorch_lightning.utilities.distributed import rank_zero_info


class Mlp(nn.Module):
    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(dim))
            self.v_bias = nn.Parameter(torch.zeros(dim))
        else:
            self.q_bias = None
            self.v_bias = None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, mask=None, relative_position_bias=None):
        B, N, C = x.shape

        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q.float() @ k.float().transpose(-2, -1)

        if relative_position_bias is not None:
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            mask = mask.bool()
            attn = attn.masked_fill(~mask[:, None, None, :], float("-inf"))

        attn = attn.softmax(dim=-1).type_as(x)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        with_vlffn=False,
        layer_scale_init_values=0.1,
        max_text_len=40,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2_text = norm_layer(dim)
        self.norm2_imag = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp_text = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_imag = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.mlp_vl = None
        if with_vlffn:
            self.mlp_vl = Mlp(
                in_features=dim,
                hidden_features=mlp_hidden_dim,
                act_layer=act_layer,
                drop=drop,
            )
            self.norm2_vl = norm_layer(dim)

        self.gamma_1 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )
        self.gamma_2 = (
            nn.Parameter(layer_scale_init_values * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_values is not None
            else 1.0
        )

        self.max_text_len = max_text_len

    def forward(self, x, mask=None, modality_type=None, relative_position_bias=None):
        x = x + self.drop_path(
            self.gamma_1
            * self.attn(self.norm1(x), mask=mask, relative_position_bias=relative_position_bias)
        )

        if modality_type == "image":
            x = x + self.drop_path(self.gamma_2 * self.mlp_imag(self.norm2_imag(x)))
        elif modality_type == "text":
            x = x + self.drop_path(self.gamma_2 * self.mlp_text(self.norm2_text(x)))
        else:
            if self.mlp_vl is None:
                x_text = x[:, : self.max_text_len]
                x_imag = x[:, self.max_text_len :]
                x_text = x_text + self.drop_path(
                    self.gamma_2 * self.mlp_text(self.norm2_text(x_text))
                )
                x_imag = x_imag + self.drop_path(
                    self.gamma_2 * self.mlp_imag(self.norm2_imag(x_imag))
                )
                x = torch.cat([x_text, x_imag], dim=1)
            else:
                x = x + self.drop_path(self.gamma_2 * self.mlp_vl(self.norm2_vl(x)))

        return x


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding"""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        no_patch_embed_bias=False,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False if no_patch_embed_bias else True,
        )

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], (
            f"Input image size ({H}*{W}) doesn't match model "
            f"({self.img_size[0]}*{self.img_size[1]})."
        )
        # FIXME look at relaxing size constraints
        x = self.proj(x)
        return x


class MultiWayTransformer(nn.Module):
    """ Vision Transformer
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=None,
        need_relative_position_embed=True,
        use_abs_pos_emb=False,
        layer_scale_init_values=0.1,
        vlffn_start_layer_index=10,
        config=None,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            num_classes (int): number of classes for classification head
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            qk_scale (float): override default qk scale of head_dim ** -0.5 if set
            drop_rate (float): dropout rate
            attn_drop_rate (float): attention dropout rate
            drop_path_rate (float): stochastic depth rate
            norm_layer: (nn.Module): normalization layer
            need_relative_position_embed (bool): enable relative position bias on self-attention
            use_abs_pos_emb (bool): enable abs pos emb
            layer_scale_init_values (float or None): layer scale init values, set None to disable
            vlffn_start_layer_index (int): vl-ffn start index
            config: (dict): other hyper from pytorch-lighting
        """
        super().__init__()
        drop_path_rate = drop_path_rate if config is None else config["drop_path_rate"]
        rank_zero_info("drop path rate: {}".format(drop_path_rate))
        self.use_abs_pos_emb = use_abs_pos_emb
        self.need_relative_position_embed = need_relative_position_embed

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.patch_size = patch_size
        self.num_heads = num_heads
        self.vlffn_start_layer_index = vlffn_start_layer_index
        if config["loss_names"]["textmlm"] > 0:
            self.vlffn_start_layer_index = depth
            rank_zero_info(
                "Set vlffn_start_layer_index={} for text-only pretraining".format(
                    self.vlffn_start_layer_index
                )
            )
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = (
            nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
            if self.use_abs_pos_emb
            else None
        )
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule
        self.blocks = nn.ModuleList(
            [
                Block(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[i],
                    norm_layer=norm_layer,
                    with_vlffn=(i >= self.vlffn_start_layer_index),
                    layer_scale_init_values=layer_scale_init_values,
                    max_text_len=config["max_text_len"],
                )
                for i in range(depth)
            ]
        )
        self.norm = norm_layer(embed_dim)

        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=0.02)
        trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"pos_embed", "cls_token"}

    def visual_embed(self, _x):
        x = self.patch_embed(_x)
        x = x.flatten(2).transpose(1, 2)
        B, L, _ = x.shape

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        x_mask = torch.ones(x.shape[0], x.shape[1])

        return x, x_mask


# VLMo base/p16
@register_model
def vlmo_base_patch16(pretrained=False, **kwargs):
    img_size = kwargs.pop("img_size", 224)
    model = MultiWayTransformer(
        img_size=img_size,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=10,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model


# VLMo large/p16
@register_model
def vlmo_large_patch16(pretrained=False, **kwargs):
    img_size = kwargs.pop("img_size", 224)
    model = MultiWayTransformer(
        img_size=img_size,
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=21,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model


# VLMo base+/p16
@register_model
def vlmo_base_plus_patch16(pretrained=False, **kwargs):
    img_size = kwargs.pop("img_size", 224)
    model = MultiWayTransformer(
        img_size=img_size,
        patch_size=16,
        embed_dim=544,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        qkv_bias=True,
        vlffn_start_layer_index=21,
        use_abs_pos_emb=True,
        need_relative_position_embed=False,
        layer_scale_init_values=None,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        **kwargs,
    )
    return model
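Because the three factory functions are registered with timm, they can be instantiated through timm.create_model once this module has been imported (so the @register_model decorators have run). A minimal sketch, assuming a config dict carrying the three keys the constructor reads (drop_path_rate, loss_names, max_text_len — the values below are placeholders, not the project's defaults):

import timm
import torch

# Placeholder config; MultiWayTransformer.__init__ reads these three keys from it.
config = {"drop_path_rate": 0.1, "loss_names": {"textmlm": 0}, "max_text_len": 40}

model = timm.create_model("vlmo_base_patch16", pretrained=False, config=config)

images = torch.randn(2, 3, 224, 224)
x, x_mask = model.visual_embed(images)   # [2, 197, 768] cls+patch embeddings and their mask
print(x.shape, x_mask.shape)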
vmlo/vlmo/modules/objectives.py (new file, mode 100644)

This diff is collapsed.