Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c81ebd1c
Unverified
Commit
c81ebd1c
authored
Sep 20, 2022
by
Alara Dirik
Committed by
GitHub
Sep 20, 2022
Browse files
Beit postprocessing (#19099)
* add post_process_semantic_segmentation method to BeiTFeatureExtractor
parent
261301d3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
2 deletions
+47
-2
docs/source/en/model_doc/beit.mdx
docs/source/en/model_doc/beit.mdx
+1
-0
src/transformers/models/beit/feature_extraction_beit.py
src/transformers/models/beit/feature_extraction_beit.py
+46
-2
No files found.
docs/source/en/model_doc/beit.mdx
View file @
c81ebd1c
...
...
@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] BeitFeatureExtractor
- __call__
- post_process_semantic_segmentation
## BeitModel
...
...
src/transformers/models/beit/feature_extraction_beit.py
View file @
c81ebd1c
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for BEiT."""
from
typing
import
Optional
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput
,
is_torch_tensor
,
)
from
...utils
import
TensorType
,
logging
from
...utils
import
TensorType
,
is_torch_available
,
logging
if
is_torch_available
():
import
torch
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs
=
BatchFeature
(
data
=
data
,
tensor_type
=
return_tensors
)
return
encoded_inputs
def post_process_semantic_segmentation(
    self, outputs, target_sizes: Optional[Union["torch.Tensor", List[Tuple[int, int]]]] = None
):
    """
    Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs ([`BeitForSemanticSegmentation`]):
            Raw outputs of the model. Only `outputs.logits` is read; the class dimension is expected at index 1.
        target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
            None, predictions will not be resized.

    Returns:
        semantic_segmentation: `torch.Tensor` of shape `(batch_size, height, width)` if `target_sizes` is None,
        otherwise a `List[torch.Tensor]` of length `batch_size`, where each item is a semantic segmentation map of
        shape `(height, width)` matching the corresponding `target_sizes` entry. Each entry of each `torch.Tensor`
        corresponds to a semantic class id (integer dtype, as produced by `argmax`).

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the logits, or if an entry
        of `target_sizes` does not hold exactly two values (h, w).
    """
    logits = outputs.logits

    # Nothing to resize: return the per-pixel argmax over the class dimension for the whole batch.
    if target_sizes is None:
        return logits.argmax(dim=1)

    if len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    size_error = "Each element of target_sizes must contain the size (h, w) of each image of the batch"
    if isinstance(target_sizes, torch.Tensor):
        if target_sizes.dim() != 2 or target_sizes.shape[1] != 2:
            raise ValueError(size_error)
        # Plain Python numbers are needed for the `size` argument of `interpolate`.
        target_sizes = target_sizes.tolist()
    elif any(len(size) != 2 for size in target_sizes):
        raise ValueError(size_error)

    # Resize the logits rather than the class-id map: interpolating class ids would blend unrelated labels,
    # whereas interpolating the scores and taking the argmax afterwards always yields a valid class id.
    semantic_segmentation = []
    for image_logits, (height, width) in zip(logits, target_sizes):
        resized_logits = torch.nn.functional.interpolate(
            image_logits.unsqueeze(dim=0),
            size=(int(height), int(width)),
            mode="bilinear",
            align_corners=False,
        )
        semantic_segmentation.append(resized_logits[0].argmax(dim=0))

    return semantic_segmentation
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment