Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c81ebd1c
Unverified
Commit
c81ebd1c
authored
Sep 20, 2022
by
Alara Dirik
Committed by
GitHub
Sep 20, 2022
Browse files
Beit postprocessing (#19099)
* add post_process_semantic_segmentation method to BeiTFeatureExtractor
parent
261301d3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
2 deletions
+47
-2
docs/source/en/model_doc/beit.mdx
docs/source/en/model_doc/beit.mdx
+1
-0
src/transformers/models/beit/feature_extraction_beit.py
src/transformers/models/beit/feature_extraction_beit.py
+46
-2
No files found.
docs/source/en/model_doc/beit.mdx
View file @
c81ebd1c
...
...
@@ -82,6 +82,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code
[[autodoc]] BeitFeatureExtractor
- __call__
- post_process_semantic_segmentation
## BeitModel
...
...
src/transformers/models/beit/feature_extraction_beit.py
View file @
c81ebd1c
...
...
@@ -14,7 +14,7 @@
# limitations under the License.
"""Feature extractor class for BEiT."""
from
typing
import
Optional
,
Union
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
from
PIL
import
Image
...
...
@@ -27,9 +27,12 @@ from ...image_utils import (
ImageInput
,
is_torch_tensor
,
)
from
...utils
import
TensorType
,
logging
from
...utils
import
TensorType
,
is_torch_available
,
logging
if
is_torch_available
():
import
torch
logger
=
logging
.
get_logger
(
__name__
)
...
...
@@ -222,3 +225,44 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
encoded_inputs
=
BatchFeature
(
data
=
data
,
tensor_type
=
return_tensors
)
return
encoded_inputs
def post_process_semantic_segmentation(
    self, outputs, target_sizes: Optional[Union["torch.Tensor", List[Tuple[int, int]]]] = None
):
    """
    Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.

    Args:
        outputs ([`BeitForSemanticSegmentation`]):
            Raw outputs of the model. Only `outputs.logits` is read; the class dimension is taken to be
            dimension 1.
        target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
            Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
            None, predictions will not be resized.

    Returns:
        semantic_segmentation: If `target_sizes` is None, a single `torch.Tensor` holding the per-pixel argmax
        over the class dimension for the whole batch. Otherwise a `List[torch.Tensor]` of length `batch_size`,
        where each item is the semantic segmentation map resized according to the corresponding `target_sizes`
        entry. Each entry of each `torch.Tensor` corresponds to a semantic class id.

    Raises:
        ValueError: If the number of target sizes differs from the batch dimension of the logits, or if an
            entry of `target_sizes` does not hold exactly two values.
    """
    logits = outputs.logits
    semantic_segmentation = logits.argmax(dim=1)

    # `target_sizes` is optional: without it there is nothing to validate or resize.
    if target_sizes is None:
        return semantic_segmentation

    if len(logits) != len(target_sizes):
        raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")

    if is_torch_tensor(target_sizes):
        # `.numpy()` only works on CPU tensors, so move the sizes off any accelerator first.
        target_sizes = target_sizes.cpu().numpy()

    # Validate entry by entry so that both arrays and plain lists of tuples are supported
    # (a list has no `.shape` attribute).
    if any(np.ndim(size) != 1 or len(size) != 2 for size in target_sizes):
        raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")

    # Resize semantic segmentation maps
    # NOTE(review): resizing is delegated to `self.resize`, which is defined outside this method. Confirm that
    # its interpolation mode and size ordering are appropriate for maps of discrete class ids.
    class_maps = semantic_segmentation.cpu().numpy()
    resized_maps = [self.resize(image=class_map, size=size) for class_map, size in zip(class_maps, target_sizes)]

    return [torch.Tensor(np.array(image)) for image in resized_maps]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment