ModelZoo / Matting-Anything_pytorch / Commits

Commit ce0e5303, authored Nov 28, 2024 by bailuo

    init

Pipeline #2003 failed with stages in 0 seconds. Changes: 172. Pipelines: 1.

Showing 20 changed files with 1744 additions and 0 deletions (+1744, -0).
GroundingDINO/groundingdino/models/__init__.py                                +18   -0
GroundingDINO/groundingdino/models/__pycache__/__init__.cpython-310.pyc       +0    -0
GroundingDINO/groundingdino/models/__pycache__/registry.cpython-310.pyc       +0    -0
GroundingDINO/groundingdino/models/registry.py                                +66   -0
GroundingDINO/groundingdino/util/__init__.py                                  +1    -0
GroundingDINO/groundingdino/util/__pycache__/__init__.cpython-310.pyc         +0    -0
GroundingDINO/groundingdino/util/__pycache__/box_ops.cpython-310.pyc          +0    -0
GroundingDINO/groundingdino/util/__pycache__/get_tokenlizer.cpython-310.pyc   +0    -0
GroundingDINO/groundingdino/util/__pycache__/inference.cpython-310.pyc        +0    -0
GroundingDINO/groundingdino/util/__pycache__/misc.cpython-310.pyc             +0    -0
GroundingDINO/groundingdino/util/__pycache__/slconfig.cpython-310.pyc         +0    -0
GroundingDINO/groundingdino/util/__pycache__/utils.cpython-310.pyc            +0    -0
GroundingDINO/groundingdino/util/__pycache__/visualizer.cpython-310.pyc       +0    -0
GroundingDINO/groundingdino/util/__pycache__/vl_utils.cpython-310.pyc         +0    -0
GroundingDINO/groundingdino/util/box_ops.py                                   +140  -0
GroundingDINO/groundingdino/util/get_tokenlizer.py                            +26   -0
GroundingDINO/groundingdino/util/inference.py                                 +256  -0
GroundingDINO/groundingdino/util/logger.py                                    +93   -0
GroundingDINO/groundingdino/util/misc.py                                      +717  -0
GroundingDINO/groundingdino/util/slconfig.py                                  +427  -0
GroundingDINO/groundingdino/models/__init__.py (new file, mode 100644)

# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from .GroundingDINO import build_groundingdino


def build_model(args):
    # we use register to maintain models from catdet6 on.
    from .registry import MODULE_BUILD_FUNCS

    assert args.modelname in MODULE_BUILD_FUNCS._module_dict
    build_func = MODULE_BUILD_FUNCS.get(args.modelname)
    model = build_func(args)
    return model
GroundingDINO/groundingdino/models/__pycache__/__init__.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/models/__pycache__/registry.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/models/registry.py (new file, mode 100644)

# ------------------------------------------------------------------------
# Grounding DINO
# url: https://github.com/IDEA-Research/GroundingDINO
# Copyright (c) 2023 IDEA. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# -*- coding: utf-8 -*-
# @Author: Yihao Chen
# @Date: 2021-08-16 16:03:17
# @Last Modified by: Shilong Liu
# @Last Modified time: 2022-01-23 15:26
# modified from mmcv

import inspect
from functools import partial


class Registry(object):
    def __init__(self, name):
        self._name = name
        self._module_dict = dict()

    def __repr__(self):
        format_str = self.__class__.__name__ + "(name={}, items={})".format(
            self._name, list(self._module_dict.keys())
        )
        return format_str

    def __len__(self):
        return len(self._module_dict)

    @property
    def name(self):
        return self._name

    @property
    def module_dict(self):
        return self._module_dict

    def get(self, key):
        return self._module_dict.get(key, None)

    def registe_with_name(self, module_name=None, force=False):
        return partial(self.register, module_name=module_name, force=force)

    def register(self, module_build_function, module_name=None, force=False):
        """Register a module build function.
        Args:
            module (:obj:`nn.Module`): Module to be registered.
        """
        if not inspect.isfunction(module_build_function):
            raise TypeError(
                "module_build_function must be a function, but got {}".format(
                    type(module_build_function)
                )
            )
        if module_name is None:
            module_name = module_build_function.__name__
        if not force and module_name in self._module_dict:
            raise KeyError("{} is already registered in {}".format(module_name, self.name))
        self._module_dict[module_name] = module_build_function

        return module_build_function


MODULE_BUILD_FUNCS = Registry("model build functions")
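
For orientation, a minimal sketch (not part of this commit) of how this registry and `build_model` from `models/__init__.py` fit together; `my_model`, `build_my_model`, and the `args` namespace are illustrative stand-ins. GroundingDINO itself registers `build_groundingdino` the same way:

from types import SimpleNamespace

# Register a build function under an explicit name via the
# `registe_with_name` decorator factory (the upstream spelling).
@MODULE_BUILD_FUNCS.registe_with_name(module_name="my_model")
def build_my_model(args):
    ...  # construct and return the nn.Module described by `args`

# build_model() asserts args.modelname is registered, then dispatches.
args = SimpleNamespace(modelname="my_model")
build_func = MODULE_BUILD_FUNCS.get(args.modelname)
model = build_func(args)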
GroundingDINO/groundingdino/util/__init__.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
GroundingDINO/groundingdino/util/__pycache__/__init__.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/box_ops.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/get_tokenlizer.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/inference.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/misc.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/slconfig.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/utils.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/visualizer.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/__pycache__/vl_utils.cpython-310.pyc (new file, mode 100644, binary file added)
GroundingDINO/groundingdino/util/box_ops.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Utilities for bounding box manipulation and GIoU.
"""
import torch
from torchvision.ops.boxes import box_area


def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def box_xyxy_to_cxcywh(x):
    x0, y0, x1, y1 = x.unbind(-1)
    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
    return torch.stack(b, dim=-1)


# modified from torchvision to also return the union
def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    # import ipdb; ipdb.set_trace()
    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / (union + 1e-6)
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    The boxes should be in [x0, y0, x1, y1] format
    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    # except:
    #     import ipdb; ipdb.set_trace()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / (area + 1e-6)


# modified from torchvision to also return the union
def box_iou_pairwise(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, :2], boxes2[:, :2])  # [N,2]
    rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])  # [N,2]

    wh = (rb - lt).clamp(min=0)  # [N,2]
    inter = wh[:, 0] * wh[:, 1]  # [N]

    union = area1 + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou_pairwise(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/
    Input:
        - boxes1, boxes2: N,4
    Output:
        - giou: N, 4
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    assert boxes1.shape == boxes2.shape
    iou, union = box_iou_pairwise(boxes1, boxes2)  # N, 4

    lt = torch.min(boxes1[:, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,2]
    area = wh[:, 0] * wh[:, 1]

    return iou - (area - union) / area


def masks_to_boxes(masks):
    """Compute the bounding boxes around the provided masks
    The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
    Returns a [N, 4] tensors, with the boxes in xyxy format
    """
    if masks.numel() == 0:
        return torch.zeros((0, 4), device=masks.device)

    h, w = masks.shape[-2:]

    y = torch.arange(0, h, dtype=torch.float)
    x = torch.arange(0, w, dtype=torch.float)
    y, x = torch.meshgrid(y, x)

    x_mask = masks * x.unsqueeze(0)
    x_max = x_mask.flatten(1).max(-1)[0]
    x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    y_mask = masks * y.unsqueeze(0)
    y_max = y_mask.flatten(1).max(-1)[0]
    y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]

    return torch.stack([x_min, y_min, x_max, y_max], 1)


if __name__ == "__main__":
    x = torch.rand(5, 4)
    y = torch.rand(3, 4)
    iou, union = box_iou(x, y)
    import ipdb

    ipdb.set_trace()
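
A short worked example of the helpers above (a sketch; the numbers are arbitrary):

import torch

# Two boxes in (cx, cy, w, h) format, converted to (x0, y0, x1, y1).
cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.4],
                       [0.3, 0.3, 0.2, 0.2]])
xyxy = box_cxcywh_to_xyxy(cxcywh)  # [[0.3, 0.3, 0.7, 0.7], [0.2, 0.2, 0.4, 0.4]]

# Pairwise [N, M] GIoU matrix; entries equal IoU when the enclosing box
# adds no extra area, so the diagonal here is ~1.0.
giou = generalized_box_iou(xyxy, xyxy)  # shape [2, 2]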
GroundingDINO/groundingdino/util/get_tokenlizer.py (new file, mode 100644)

from transformers import AutoTokenizer, BertModel, BertTokenizer, RobertaModel, RobertaTokenizerFast


def get_tokenlizer(text_encoder_type):
    if not isinstance(text_encoder_type, str):
        # print("text_encoder_type is not a str")
        if hasattr(text_encoder_type, "text_encoder_type"):
            text_encoder_type = text_encoder_type.text_encoder_type
        elif text_encoder_type.get("text_encoder_type", False):
            text_encoder_type = text_encoder_type.get("text_encoder_type")
        else:
            raise ValueError(
                "Unknown type of text_encoder_type: {}".format(type(text_encoder_type))
            )
    print("final text_encoder_type: {}".format(text_encoder_type))

    tokenizer = AutoTokenizer.from_pretrained(text_encoder_type)
    return tokenizer


def get_pretrained_language_model(text_encoder_type):
    if text_encoder_type == "bert-base-uncased":
        return BertModel.from_pretrained(text_encoder_type)
    if text_encoder_type == "roberta-base":
        return RobertaModel.from_pretrained(text_encoder_type)
    raise ValueError("Unknown text_encoder_type {}".format(text_encoder_type))
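
A minimal usage sketch, assuming the `bert-base-uncased` weights are available to `transformers` (cached locally or downloadable):

# Accepts either a plain string or a config object carrying a
# `text_encoder_type` attribute, e.g. the SLConfig loaded for the model.
tokenizer = get_tokenlizer("bert-base-uncased")
bert = get_pretrained_language_model("bert-base-uncased")

tokens = tokenizer("a cat. a dog.", return_tensors="pt")
outputs = bert(**tokens)  # outputs.last_hidden_state: [1, seq_len, 768]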
GroundingDINO/groundingdino/util/inference.py (new file, mode 100644)

from typing import Tuple, List

import cv2
import numpy as np
import supervision as sv
import torch
from PIL import Image
from torchvision.ops import box_convert

import groundingdino.datasets.transforms as T
from groundingdino.models import build_model
from groundingdino.util.misc import clean_state_dict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import get_phrases_from_posmap
import pdb

# ----------------------------------------------------------------------------------------------------------------------
# OLD API
# ----------------------------------------------------------------------------------------------------------------------


def preprocess_caption(caption: str) -> str:
    result = caption.lower().strip()
    if result.endswith("."):
        return result
    return result + "."


def load_model(model_config_path: str, model_checkpoint_path: str, device: str = "cuda"):
    args = SLConfig.fromfile(model_config_path)
    args.device = device
    model = build_model(args)
    checkpoint = torch.load(model_checkpoint_path, map_location="cpu")
    model.load_state_dict(clean_state_dict(checkpoint["model"]), strict=False)
    model.eval()
    return model


def load_image(image_path: str) -> Tuple[np.array, torch.Tensor]:
    transform = T.Compose(
        [
            T.RandomResize([800], max_size=1333),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    image_source = Image.open(image_path).convert("RGB")
    image = np.asarray(image_source)
    image_transformed, _ = transform(image_source, None)
    return image, image_transformed


def predict(
    model,
    image: torch.Tensor,
    caption: str,
    box_threshold: float,
    text_threshold: float,
    device: str = "cuda",
) -> Tuple[torch.Tensor, torch.Tensor, List[str]]:
    caption = preprocess_caption(caption=caption)

    model = model.to(device)
    image = image.to(device)

    with torch.no_grad():
        outputs = model(image[None], captions=[caption])

    prediction_logits = outputs["pred_logits"].cpu().sigmoid()[0]  # prediction_logits.shape = (nq, 256)
    prediction_boxes = outputs["pred_boxes"].cpu()[0]  # prediction_boxes.shape = (nq, 4)

    mask = prediction_logits.max(dim=1)[0] > box_threshold
    logits = prediction_logits[mask]  # logits.shape = (n, 256)
    boxes = prediction_boxes[mask]  # boxes.shape = (n, 4)

    tokenizer = model.tokenizer
    tokenized = tokenizer(caption)

    phrases = [
        get_phrases_from_posmap(logit > text_threshold, tokenized, tokenizer).replace('.', '')
        for logit in logits
    ]

    return boxes, logits.max(dim=1)[0], phrases


def annotate(image_source: np.ndarray, boxes: torch.Tensor, logits: torch.Tensor, phrases: List[str]) -> np.ndarray:
    h, w, _ = image_source.shape
    boxes = boxes * torch.Tensor([w, h, w, h])
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    detections = sv.Detections(xyxy=xyxy)

    labels = [
        f"{phrase} {logit:.2f}"
        for phrase, logit in zip(phrases, logits)
    ]

    box_annotator = sv.BoxAnnotator()
    annotated_frame = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
    return annotated_frame


# ----------------------------------------------------------------------------------------------------------------------
# NEW API
# ----------------------------------------------------------------------------------------------------------------------


class Model:

    def __init__(
        self,
        model_config_path: str,
        model_checkpoint_path: str,
        device: str = "cuda",
    ):
        self.model = load_model(
            model_config_path=model_config_path,
            model_checkpoint_path=model_checkpoint_path,
            device=device,
        ).to(device)
        self.device = device

    def predict_with_caption(
        self,
        image: np.ndarray,
        caption: str,
        box_threshold: float = 0.35,
        text_threshold: float = 0.25,
    ) -> Tuple[sv.Detections, List[str]]:
        """
        import cv2

        image = cv2.imread(IMAGE_PATH)

        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
        detections, labels = model.predict_with_caption(
            image=image,
            caption=caption,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD
        )

        import supervision as sv

        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image, detections=detections, labels=labels)
        """
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        return detections, phrases

    def predict_with_classes(
        self,
        image: np.ndarray,
        classes: List[str],
        box_threshold: float,
        text_threshold: float,
    ) -> sv.Detections:
        """
        import cv2

        image = cv2.imread(IMAGE_PATH)

        model = Model(model_config_path=CONFIG_PATH, model_checkpoint_path=WEIGHTS_PATH)
        detections = model.predict_with_classes(
            image=image,
            classes=CLASSES,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD
        )

        import supervision as sv

        box_annotator = sv.BoxAnnotator()
        annotated_image = box_annotator.annotate(scene=image, detections=detections)
        """
        caption = ". ".join(classes)
        processed_image = Model.preprocess_image(image_bgr=image).to(self.device)
        boxes, logits, phrases = predict(
            model=self.model,
            image=processed_image,
            caption=caption,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            device=self.device)
        source_h, source_w, _ = image.shape
        detections = Model.post_process_result(
            source_h=source_h,
            source_w=source_w,
            boxes=boxes,
            logits=logits)
        class_id = Model.phrases2classes(phrases=phrases, classes=classes)
        detections.class_id = class_id
        return detections

    @staticmethod
    def preprocess_image(image_bgr: np.ndarray) -> torch.Tensor:
        transform = T.Compose(
            [
                T.RandomResize([800], max_size=1333),
                T.ToTensor(),
                T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image_pillow = Image.fromarray(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        image_transformed, _ = transform(image_pillow, None)
        return image_transformed

    @staticmethod
    def post_process_result(
        source_h: int,
        source_w: int,
        boxes: torch.Tensor,
        logits: torch.Tensor,
    ) -> sv.Detections:
        boxes = boxes * torch.Tensor([source_w, source_h, source_w, source_h])
        xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
        confidence = logits.numpy()
        return sv.Detections(xyxy=xyxy, confidence=confidence)

    @staticmethod
    def phrases2classes(phrases: List[str], classes: List[str]) -> np.ndarray:
        class_ids = []
        for phrase in phrases:
            try:
                # class_ids.append(classes.index(phrase))
                class_ids.append(Model.find_index(phrase, classes))
            except ValueError:
                class_ids.append(None)
        return np.array(class_ids)

    @staticmethod
    def find_index(string, lst):
        # if meet string like "lake river" will only keep "lake"
        # this is an hack implementation for visualization which will be updated in the future
        string = string.lower().split()[0]
        for i, s in enumerate(lst):
            if string in s.lower():
                return i
        return -1
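
A sketch of the OLD API end to end; the paths, caption, and thresholds below are placeholders, and the imports `groundingdino.datasets.transforms` and `groundingdino.util.utils` come from files outside this page of the diff:

# Hypothetical paths; substitute the real config and checkpoint.
CONFIG_PATH = "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
WEIGHTS_PATH = "weights/groundingdino_swint_ogc.pth"

model = load_model(CONFIG_PATH, WEIGHTS_PATH, device="cuda")
image_source, image = load_image("assets/demo.jpg")

boxes, logits, phrases = predict(
    model=model,
    image=image,
    caption="a dog. a frisbee.",
    box_threshold=0.35,
    text_threshold=0.25,
)

# Boxes come back normalized in cxcywh; annotate() rescales and draws them.
annotated = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
cv2.imwrite("annotated.jpg", annotated)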
GroundingDINO/groundingdino/util/logger.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys

from termcolor import colored


class _ColorfulFormatter(logging.Formatter):
    def __init__(self, *args, **kwargs):
        self._root_name = kwargs.pop("root_name") + "."
        self._abbrev_name = kwargs.pop("abbrev_name", "")
        if len(self._abbrev_name):
            self._abbrev_name = self._abbrev_name + "."
        super(_ColorfulFormatter, self).__init__(*args, **kwargs)

    def formatMessage(self, record):
        record.name = record.name.replace(self._root_name, self._abbrev_name)
        log = super(_ColorfulFormatter, self).formatMessage(record)
        if record.levelno == logging.WARNING:
            prefix = colored("WARNING", "red", attrs=["blink"])
        elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
            prefix = colored("ERROR", "red", attrs=["blink", "underline"])
        else:
            return log
        return prefix + " " + log


# so that calling setup_logger multiple times won't add many handlers
@functools.lru_cache()
def setup_logger(output=None, distributed_rank=0, *, color=True, name="imagenet", abbrev_name=None):
    """
    Initialize the detectron2 logger and set its verbosity level to "INFO".
    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger
    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
    logger.propagate = False

    if abbrev_name is None:
        abbrev_name = name

    plain_formatter = logging.Formatter(
        "[%(asctime)s.%(msecs)03d]: %(message)s", datefmt="%m/%d %H:%M:%S"
    )
    # stdout logging: master only
    if distributed_rank == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(logging.DEBUG)
        if color:
            formatter = _ColorfulFormatter(
                colored("[%(asctime)s.%(msecs)03d]: ", "green") + "%(message)s",
                datefmt="%m/%d %H:%M:%S",
                root_name=name,
                abbrev_name=str(abbrev_name),
            )
        else:
            formatter = plain_formatter
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    # file logging: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")
        if distributed_rank > 0:
            filename = filename + f".rank{distributed_rank}"
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        fh = logging.StreamHandler(_cached_log_stream(filename))
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(plain_formatter)
        logger.addHandler(fh)

    return logger


# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    return open(filename, "a")
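
A usage sketch (the names are illustrative):

# One call per process; repeated calls with identical arguments are
# deduplicated by functools.lru_cache and return the same logger.
logger = setup_logger(output="logs", distributed_rank=0, name="groundingdino")
logger.info("training started")  # colored timestamp on stdout, plain text in logs/log.txt
logger.warning("loss is NaN")    # stdout line gets a red WARNING prefix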
GroundingDINO/groundingdino/util/misc.py (new file, mode 100644)

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
Mostly copy-paste from torchvision references.
"""
import colorsys
import datetime
import functools
import io
import json
import os
import pickle
import subprocess
import time
from collections import OrderedDict, defaultdict, deque
from typing import List, Optional

import numpy as np
import torch
import torch.distributed as dist

# needed due to empty tensor bug in pytorch and torchvision 0.5
import torchvision
from torch import Tensor

__torchvision_need_compat_flag = float(torchvision.__version__.split(".")[1]) < 7
if __torchvision_need_compat_flag:
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        if d.shape[0] == 0:
            return 0
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        if os.environ.get("SHILONG_AMP", None) == "1":
            eps = 1e-4
        else:
            eps = 1e-6
        return self.total / (self.count + eps)

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
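
A small sketch of SmoothedValue in isolation (the values are made up):

meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
for step in range(100):
    meter.update(1.0 / (step + 1))  # e.g. a per-step loss

print(str(meter))  # median over the last 20 updates (global average over all 100)
print(meter.max)   # largest value still inside the 20-element window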
@functools.lru_cache()
def _get_global_gloo_group():
    """
    Return a process group based on gloo backend, containing all the ranks
    The result is cached.
    """
    if dist.get_backend() == "nccl":
        return dist.new_group(backend="gloo")

    return dist.group.WORLD


def all_gather_cpu(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    cpu_group = _get_global_gloo_group()

    buffer = io.BytesIO()
    torch.save(data, buffer)
    data_view = buffer.getbuffer()
    device = "cuda" if cpu_group is None else "cpu"
    tensor = torch.ByteTensor(data_view).to(device)

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
    size_list = [torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)]
    if cpu_group is None:
        dist.all_gather(size_list, local_size)
    else:
        print("gathering on cpu")
        dist.all_gather(size_list, local_size, group=cpu_group)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)
    assert isinstance(local_size.item(), int)
    local_size = int(local_size.item())

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device=device)
        tensor = torch.cat((tensor, padding), dim=0)
    if cpu_group is None:
        dist.all_gather(tensor_list, tensor)
    else:
        dist.all_gather(tensor_list, tensor, group=cpu_group)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
        buffer = io.BytesIO(tensor.cpu().numpy())
        obj = torch.load(buffer)
        data_list.append(obj)

    return data_list


def all_gather(data):
    """
    Run all_gather on arbitrary picklable data (not necessarily tensors)
    Args:
        data: any picklable object
    Returns:
        list[data]: list of data gathered from each rank
    """

    if os.getenv("CPU_REDUCE") == "1":
        return all_gather_cpu(data)

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    # serialized to a Tensor
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to("cuda")

    # obtain Tensor size of each rank
    local_size = torch.tensor([tensor.numel()], device="cuda")
    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # receiving Tensor from all ranks
    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
    if local_size != max_size:
        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
        tensor = torch.cat((tensor, padding), dim=0)
    dist.all_gather(tensor_list, tensor)

    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_dict(input_dict, average=True):
    """
    Args:
        input_dict (dict): all the values will be reduced
        average (bool): whether to do average or sum
    Reduce the values in the dictionary from all processes so that all processes
    have the averaged results. Returns a dict with the same fields as
    input_dict, after reduction.
    """
    world_size = get_world_size()
    if world_size < 2:
        return input_dict
    with torch.no_grad():
        names = []
        values = []
        # sort the keys so that they are consistent across processes
        for k in sorted(input_dict.keys()):
            names.append(k)
            values.append(input_dict[k])
        values = torch.stack(values, dim=0)
        dist.all_reduce(values)
        if average:
            values /= world_size
        reduced_dict = {k: v for k, v in zip(names, values)}
    return reduced_dict


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            # print(name, str(meter))
            # import ipdb;ipdb.set_trace()
            if meter.count > 0:
                loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None, logger=None):
        if logger is None:
            print_func = print
        else:
            print_func = logger.info

        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        if torch.cuda.is_available():
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                    "max mem: {memory:.0f}",
                ]
            )
        else:
            log_msg = self.delimiter.join(
                [
                    header,
                    "[{0" + space_fmt + "}/{1}]",
                    "eta: {eta}",
                    "{meters}",
                    "time: {time}",
                    "data: {data}",
                ]
            )
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            # import ipdb; ipdb.set_trace()
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print_func(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print_func(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print_func(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )
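
How MetricLogger is typically driven from a training loop (a sketch; `range(1000)` stands in for a DataLoader, and the loss value is a placeholder):

metric_logger = MetricLogger(delimiter="  ")

for batch in metric_logger.log_every(range(1000), print_freq=100, header="Epoch: [0]"):
    loss = 0.5  # placeholder for a real forward/backward step
    metric_logger.update(loss=loss, lr=1e-4)

# Meters are also readable directly, e.g. metric_logger.loss.global_avg.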
def get_sha():
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommited changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message


def collate_fn(batch):
    # import ipdb; ipdb.set_trace()
    batch = list(zip(*batch))
    batch[0] = nested_tensor_from_tensor_list(batch[0])
    return tuple(batch)


def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
    maxes = the_list[0]
    for sublist in the_list[1:]:
        for index, item in enumerate(sublist):
            maxes[index] = max(maxes[index], item)
    return maxes


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask
        if mask == "auto":
            self.mask = torch.zeros_like(tensors).to(tensors.device)
            if self.mask.dim() == 3:
                self.mask = self.mask.sum(0).to(bool)
            elif self.mask.dim() == 4:
                self.mask = self.mask.sum(1).to(bool)
            else:
                raise ValueError(
                    "tensors dim must be 3 or 4 but {}({})".format(
                        self.tensors.dim(), self.tensors.shape
                    )
                )

    def imgsize(self):
        res = []
        for i in range(self.tensors.shape[0]):
            mask = self.mask[i]
            maxH = (~mask).sum(0).max()
            maxW = (~mask).sum(1).max()
            res.append(torch.Tensor([maxH, maxW]))
        return res

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        cast_tensor = self.tensors.to(device)
        mask = self.mask
        if mask is not None:
            assert mask is not None
            cast_mask = mask.to(device)
        else:
            cast_mask = None
        return NestedTensor(cast_tensor, cast_mask)

    def to_img_list_single(self, tensor, mask):
        assert tensor.dim() == 3, "dim of tensor should be 3 but {}".format(tensor.dim())
        maxH = (~mask).sum(0).max()
        maxW = (~mask).sum(1).max()
        img = tensor[:, :maxH, :maxW]
        return img

    def to_img_list(self):
        """remove the padding and convert to img list
        Returns:
            [type]: [description]
        """
        if self.tensors.dim() == 3:
            return self.to_img_list_single(self.tensors, self.mask)
        else:
            res = []
            for i in range(self.tensors.shape[0]):
                tensor_i = self.tensors[i]
                mask_i = self.mask[i]
                res.append(self.to_img_list_single(tensor_i, mask_i))
            return res

    @property
    def device(self):
        return self.tensors.device

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)

    @property
    def shape(self):
        return {"tensors.shape": self.tensors.shape, "mask.shape": self.mask.shape}


def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    # TODO make this more general
    if tensor_list[0].ndim == 3:
        if torchvision._is_tracing():
            # nested_tensor_from_tensor_list() does not export well to ONNX
            # call _onnx_nested_tensor_from_tensor_list() instead
            return _onnx_nested_tensor_from_tensor_list(tensor_list)

        # TODO make it support different-sized images
        max_size = _max_by_axis([list(img.shape) for img in tensor_list])
        # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
        batch_shape = [len(tensor_list)] + max_size
        b, c, h, w = batch_shape
        dtype = tensor_list[0].dtype
        device = tensor_list[0].device
        tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
        mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
        for img, pad_img, m in zip(tensor_list, tensor, mask):
            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
            m[: img.shape[1], : img.shape[2]] = False
    else:
        raise ValueError("not supported")
    return NestedTensor(tensor, mask)
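
A quick sketch of the padding behavior (shapes chosen arbitrarily):

import torch

# Two images of different sizes are zero-padded to a shared (C, H, W);
# the boolean mask marks padded pixels with True.
imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
samples = nested_tensor_from_tensor_list(imgs)

tensors, mask = samples.decompose()
print(tensors.shape)  # torch.Size([2, 3, 512, 640])
print(mask.shape)     # torch.Size([2, 512, 640])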
# _onnx_nested_tensor_from_tensor_list() is an implementation of
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    max_size = []
    for i in range(tensor_list[0].dim()):
        max_size_i = torch.max(
            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
        ).to(torch.int64)
        max_size.append(max_size_i)
    max_size = tuple(max_size)

    # work around for
    # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
    # m[: img.shape[1], :img.shape[2]] = False
    # which is not yet supported in onnx
    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
        padded_imgs.append(padded_img)

        m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(padded_mask.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def save_on_master(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def init_distributed_mode(args):
    if "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "":  # 'RANK' in os.environ and
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = args.local_rank = int(os.environ["LOCAL_RANK"])

        # launch by torch.distributed.launch
        # Single node
        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 1 --rank 0 ...
        # Multi nodes
        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 0 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
        #   python -m torch.distributed.launch --nproc_per_node=8 main.py --world-size 2 --rank 1 --dist-url 'tcp://IP_OF_NODE0:FREEPORT' ...
        # args.rank = int(os.environ.get('OMPI_COMM_WORLD_RANK'))
        # local_world_size = int(os.environ['GPU_PER_NODE_COUNT'])
        # args.world_size = args.world_size * local_world_size
        # args.gpu = args.local_rank = int(os.environ['LOCAL_RANK'])
        # args.rank = args.rank * local_world_size + args.local_rank
        print(
            "world size: {}, rank: {}, local rank: {}".format(
                args.world_size, args.rank, args.local_rank
            )
        )
        print(json.dumps(dict(os.environ), indent=2))
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.local_rank = int(os.environ["SLURM_LOCALID"])
        args.world_size = int(os.environ["SLURM_NPROCS"])

        print(
            "world size: {}, world rank: {}, local rank: {}, device_count: {}".format(
                args.world_size, args.rank, args.local_rank, torch.cuda.device_count()
            )
        )
    else:
        print("Not using distributed mode")
        args.distributed = False
        args.world_size = 1
        args.rank = 0
        args.local_rank = 0
        return

    print("world_size:{} rank:{} local_rank:{}".format(args.world_size, args.rank, args.local_rank))
    args.distributed = True
    torch.cuda.set_device(args.local_rank)
    args.dist_backend = "nccl"
    print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)

    torch.distributed.init_process_group(
        backend=args.dist_backend,
        world_size=args.world_size,
        rank=args.rank,
        init_method=args.dist_url,
    )

    print("Before torch.distributed.barrier()")
    torch.distributed.barrier()
    print("End torch.distributed.barrier()")
    setup_for_distributed(args.rank == 0)


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


@torch.no_grad()
def accuracy_onehot(pred, gt):
    """_summary_
    Args:
        pred (_type_): n, c
        gt (_type_): n, c
    """
    tp = ((pred - gt).abs().sum(-1) < 1e-4).float().sum()
    acc = tp / gt.shape[0] * 100
    return acc


def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
    This will eventually be supported natively by PyTorch, and this
    class can go away.
    """
    if __torchvision_need_compat_flag < 0.7:
        if input.numel() > 0:
            return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)

        output_shape = _output_size(2, input, size, scale_factor)
        output_shape = list(input.shape[:-2]) + list(output_shape)
        return _new_empty_tensor(input, output_shape)
    else:
        return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)


class color_sys:
    def __init__(self, num_colors) -> None:
        self.num_colors = num_colors
        colors = []
        for i in np.arange(0.0, 360.0, 360.0 / num_colors):
            hue = i / 360.0
            lightness = (50 + np.random.rand() * 10) / 100.0
            saturation = (90 + np.random.rand() * 10) / 100.0
            colors.append(
                tuple([int(j * 255) for j in colorsys.hls_to_rgb(hue, lightness, saturation)])
            )
        self.colors = colors

    def __call__(self, idx):
        return self.colors[idx]


def inverse_sigmoid(x, eps=1e-3):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)


def clean_state_dict(state_dict):
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if k[:7] == "module.":  # remove `module.`
            k = k[7:]
        new_state_dict[k] = v
    return new_state_dict
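
Two of the small helpers above, exercised directly (a sketch):

import torch

# DataParallel/DistributedDataParallel checkpoints prefix keys with
# "module."; clean_state_dict() strips it so a bare model can load them
# (this is what load_model in inference.py relies on).
ckpt = {"module.backbone.weight": torch.zeros(1), "head.bias": torch.zeros(1)}
print(list(clean_state_dict(ckpt).keys()))  # ['backbone.weight', 'head.bias']

# inverse_sigmoid is a clamped logit: it undoes torch.sigmoid up to eps.
p = torch.tensor([0.25, 0.50, 0.75])
print(torch.sigmoid(inverse_sigmoid(p)))  # ~[0.25, 0.50, 0.75]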
GroundingDINO/groundingdino/util/slconfig.py (new file, mode 100644)

# ==========================================================
# Modified from mmcv
# ==========================================================
import ast
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action
from importlib import import_module
import platform

from addict import Dict
from yapf.yapflib.yapf_api import FormatCode

BASE_KEY = "_base_"
DELETE_KEY = "_delete_"
RESERVED_KEYS = ["filename", "text", "pretty_text", "get", "dump", "merge_from_dict"]


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


class ConfigDict(Dict):
    def __missing__(self, name):
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            ex = AttributeError(f"'{self.__class__.__name__}' object has no " f"attribute '{name}'")
        except Exception as e:
            ex = e
        else:
            return value
        raise ex


class SLConfig(object):
    """
    config files.
    only support .py file as config now.
    ref: mmcv.utils.config
    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]
        >>> cfg = Config.fromfile('tests/data/config/a.py')
        >>> cfg.filename
        "/home/kchen/projects/mmcv/tests/data/config/a.py"
        >>> cfg.item4
        'test'
        >>> cfg
        "Config [path: /home/kchen/projects/mmcv/tests/data/config/a.py]: "
        "{'item1': [1, 2], 'item2': {'a': 0}, 'item3': True, 'item4': 'test'}"
    """

    @staticmethod
    def _validate_py_syntax(filename):
        with open(filename) as f:
            content = f.read()
        try:
            ast.parse(content)
        except SyntaxError:
            raise SyntaxError("There are syntax errors in config " f"file {filename}")

    @staticmethod
    def _file2dict(filename):
        filename = osp.abspath(osp.expanduser(filename))
        check_file_exist(filename)
        if filename.lower().endswith(".py"):
            with tempfile.TemporaryDirectory() as temp_config_dir:
                temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=".py")
                temp_config_name = osp.basename(temp_config_file.name)
                if platform.system() == 'Windows':
                    temp_config_file.close()
                shutil.copyfile(filename, osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                SLConfig._validate_py_syntax(filename)
                mod = import_module(temp_module_name)
                sys.path.pop(0)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith("__")
                }
                # delete imported module
                del sys.modules[temp_module_name]
                # close temp file
                temp_config_file.close()
        elif filename.lower().endswith((".yml", ".yaml", ".json")):
            from .slio import slload

            cfg_dict = slload(filename)
        else:
            raise IOError("Only py/yml/yaml/json type are supported now!")

        cfg_text = filename + "\n"
        with open(filename, "r") as f:
            cfg_text += f.read()

        # parse the base file
        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = base_filename if isinstance(base_filename, list) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                _cfg_dict, _cfg_text = SLConfig._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError("Duplicate key is not allowed among bases")
                    # TODO Allow the duplicate key while warnning user
                base_cfg_dict.update(c)

            base_cfg_dict = SLConfig._merge_a_into_b(cfg_dict, base_cfg_dict)
            cfg_dict = base_cfg_dict

            # merge cfg_text
            cfg_text_list.append(cfg_text)
            cfg_text = "\n".join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """merge dict `a` into dict `b` (non-inplace).
        values in `a` will overwrite `b`.
        copy first to avoid inplace modification
        Args:
            a ([type]): [description]
            b ([type]): [description]
        Returns:
            [dict]: [description]
        """
        # import ipdb; ipdb.set_trace()
        if not isinstance(a, dict):
            return a
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b and not v.pop(DELETE_KEY, False):
                if not isinstance(b[k], dict) and not isinstance(b[k], list):
                    # if :
                    #     import ipdb; ipdb.set_trace()
                    raise TypeError(
                        f"{k}={v} in child config cannot inherit from base "
                        f"because {k} is a dict in the child config but is of "
                        f"type {type(b[k])} in base config. You may set "
                        f"`{DELETE_KEY}=True` to ignore the base config"
                    )
                b[k] = SLConfig._merge_a_into_b(v, b[k])
            elif isinstance(b, list):
                try:
                    _ = int(k)
                except:
                    raise TypeError(
                        f"b is a list, " f"index {k} should be an int when input but {type(k)}"
                    )
                b[int(k)] = SLConfig._merge_a_into_b(v, b[int(k)])
            else:
                b[k] = v
        return b

    @staticmethod
    def fromfile(filename):
        cfg_dict, cfg_text = SLConfig._file2dict(filename)
        return SLConfig(cfg_dict, cfg_text=cfg_text, filename=filename)

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError("cfg_dict must be a dict, but " f"got {type(cfg_dict)}")
        for key in cfg_dict:
            if key in RESERVED_KEYS:
                raise KeyError(f"{key} is reserved for config file")

        super(SLConfig, self).__setattr__("_cfg_dict", ConfigDict(cfg_dict))
        super(SLConfig, self).__setattr__("_filename", filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, "r") as f:
                text = f.read()
        else:
            text = ""
        super(SLConfig, self).__setattr__("_text", text)

    @property
    def filename(self):
        return self._filename

    @property
    def text(self):
        return self._text

    @property
    def pretty_text(self):
        indent = 4

        def _indent(s_, num_spaces):
            s = s_.split("\n")
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * " ") + line for line in s]
            s = "\n".join(s)
            s = first + "\n" + s
            return s

        def _format_basic_types(k, v, use_mapping=False):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)

            if use_mapping:
                k_str = f"'{k}'" if isinstance(k, str) else str(k)
                attr_str = f"{k_str}: {v_str}"
            else:
                attr_str = f"{str(k)}={v_str}"
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v, use_mapping=False):
            # check if all items in the list are dict
            if all(isinstance(_, dict) for _ in v):
                v_str = "[\n"
                v_str += "\n".join(
                    f"dict({_indent(_format_dict(v_), indent)})," for v_ in v
                ).rstrip(",")
                if use_mapping:
                    k_str = f"'{k}'" if isinstance(k, str) else str(k)
                    attr_str = f"{k_str}: {v_str}"
                else:
                    attr_str = f"{str(k)}={v_str}"
                attr_str = _indent(attr_str, indent) + "]"
            else:
                attr_str = _format_basic_types(k, v, use_mapping)
            return attr_str

        def _contain_invalid_identifier(dict_str):
            contain_invalid_identifier = False
            for key_name in dict_str:
                contain_invalid_identifier |= not str(key_name).isidentifier()
            return contain_invalid_identifier

        def _format_dict(input_dict, outest_level=False):
            r = ""
            s = []

            use_mapping = _contain_invalid_identifier(input_dict)
            if use_mapping:
                r += "{"
            for idx, (k, v) in enumerate(input_dict.items()):
                is_last = idx >= len(input_dict) - 1
                end = "" if outest_level or is_last else ","
                if isinstance(v, dict):
                    v_str = "\n" + _format_dict(v)
                    if use_mapping:
                        k_str = f"'{k}'" if isinstance(k, str) else str(k)
                        attr_str = f"{k_str}: dict({v_str}"
                    else:
                        attr_str = f"{str(k)}=dict({v_str}"
                    attr_str = _indent(attr_str, indent) + ")" + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v, use_mapping) + end
                else:
                    attr_str = _format_basic_types(k, v, use_mapping) + end

                s.append(attr_str)
            r += "\n".join(s)
            if use_mapping:
                r += "}"
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)
        # copied from setup.cfg
        yapf_style = dict(
            based_on_style="pep8",
            blank_line_before_nested_class_or_def=True,
            split_before_expression_after_opening_paren=True,
        )
        text, _ = FormatCode(text, style_config=yapf_style, verify=True)

        return text

    def __repr__(self):
        return f"Config (path: {self.filename}): {self._cfg_dict.__repr__()}"

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # # debug
        # print('+'*15)
        # print('name=%s' % name)
        # print("addr:", id(self))
        # # print('type(self):', type(self))
        # print(self.__dict__)
        # print('+'*15)
        # if self.__dict__ == {}:
        #     raise ValueError
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self, file=None):
        # import ipdb; ipdb.set_trace()
        if file is None:
            return self.pretty_text
        else:
            with open(file, "w") as f:
                f.write(self.pretty_text)

    def merge_from_dict(self, options):
        """Merge list into cfg_dict
        Merge the dict parsed by MultipleKVAction into this cfg.
        Examples:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp':True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
            >>> assert cfg_dict == dict(
            ...     model=dict(backbone=dict(depth=50, with_cp=True)))
        Args:
            options (dict): dict of configs to merge from.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split(".")
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(SLConfig, self).__getattribute__("_cfg_dict")
        super(SLConfig, self).__setattr__(
            "_cfg_dict", SLConfig._merge_a_into_b(option_cfg_dict, cfg_dict)
        )

    # for multiprocess
    def __setstate__(self, state):
        self.__init__(state)

    def copy(self):
        return SLConfig(self._cfg_dict.copy())

    def deepcopy(self):
        return SLConfig(self._cfg_dict.deepcopy())


class DictAction(Action):
    """
    argparse action to split an argument into KEY=VALUE form
    on the first = and append to a dictionary. List options should
    be passed as comma separated values, i.e KEY=V1,V2,V3
    """

    @staticmethod
    def _parse_int_float_bool(val):
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ["true", "false"]:
            return True if val.lower() == "true" else False
        if val.lower() in ["none", "null"]:
            return None
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, val = kv.split("=", maxsplit=1)
            val = [self._parse_int_float_bool(v) for v in val.split(",")]
            if len(val) == 1:
                val = val[0]
            options[key] = val
        setattr(namespace, self.dest, options)
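
A usage sketch for SLConfig; `cfg.py` is a hypothetical config file:

# cfg.py might contain, e.g.:
#   modelname = "groundingdino"
#   model = dict(backbone=dict(type="swin_T", depth=12))
cfg = SLConfig.fromfile("cfg.py")
print(cfg.modelname)                     # attribute access via ConfigDict
print(cfg["model"]["backbone"]["type"])  # item access also works

# Dotted keys merge into the nested dict, mirroring mmcv's Config.
cfg.merge_from_dict({"model.backbone.depth": 24})
print(cfg.model.backbone.depth)          # 24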