Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
0558914d
Unverified
Commit
0558914d
authored
Mar 22, 2023
by
Alara Dirik
Committed by
GitHub
Mar 22, 2023
Browse files
Add MaskedImageModelingOutput (#22212)
* Add MaskedImageModelingOutput
parent
0dcb46e7
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
116 additions
and
28 deletions
+116
-28
src/transformers/modeling_outputs.py
src/transformers/modeling_outputs.py
+38
-0
src/transformers/modeling_tf_outputs.py
src/transformers/modeling_tf_outputs.py
+40
-2
src/transformers/models/deit/modeling_deit.py
src/transformers/models/deit/modeling_deit.py
+11
-6
src/transformers/models/deit/modeling_tf_deit.py
src/transformers/models/deit/modeling_tf_deit.py
+10
-8
src/transformers/models/vit/modeling_vit.py
src/transformers/models/vit/modeling_vit.py
+11
-6
tests/models/deit/test_modeling_deit.py
tests/models/deit/test_modeling_deit.py
+2
-2
tests/models/deit/test_modeling_tf_deit.py
tests/models/deit/test_modeling_tf_deit.py
+2
-2
tests/models/vit/test_modeling_vit.py
tests/models/vit/test_modeling_vit.py
+2
-2
No files found.
src/transformers/modeling_outputs.py
View file @
0558914d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
...
@@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput):
...
@@ -1622,3 +1623,40 @@ class SampleTSPredictionOutput(ModelOutput):
"""
"""
sequences
:
torch
.
FloatTensor
=
None
sequences
:
torch
.
FloatTensor
=
None
@dataclass
class MaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or
        when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states
            (also called feature maps) of the model at the output of each stage.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, patch_size,
            sequence_length)`. Attention weights after the attention softmax, used to compute the weighted average in
            the self-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
    reconstruction: torch.FloatTensor = None
    # Variable-length tuples: one tensor per stage / layer, not a 1-tuple.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None

    @property
    def logits(self):
        """
        Deprecated alias for `reconstruction`.

        Emits a `FutureWarning` on every access and returns `self.reconstruction` unchanged.
        """
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
            # Attribute the warning to the code that read `.logits`, not to this property body.
            stacklevel=2,
        )
        return self.reconstruction
src/transformers/modeling_tf_outputs.py
View file @
0558914d
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import
warnings
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Tuple
from
typing
import
List
,
Optional
,
Tuple
...
@@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
...
@@ -55,8 +56,8 @@ class TFBaseModelOutputWithNoAttention(ModelOutput):
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
last_hidden_state (`tf.Tensor` shape `(batch_size, num_channels, height, width)`):
Sequence of hidden-states at the output of the last layer of the model.
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
Tuple of `t
orch.Float
Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
Tuple of `t
f.
Tensor` (one for the output of the embeddings, if the model has an embedding layer, +
one for
one for
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
the output of each layer) of shape `(batch_size, num_channels, height, width)`.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
"""
"""
...
@@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
...
@@ -949,3 +950,40 @@ class TFImageClassifierOutputWithNoAttention(ModelOutput):
loss
:
Optional
[
tf
.
Tensor
]
=
None
loss
:
Optional
[
tf
.
Tensor
]
=
None
logits
:
tf
.
Tensor
=
None
logits
:
tf
.
Tensor
=
None
hidden_states
:
Optional
[
Tuple
[
tf
.
Tensor
,
...]]
=
None
hidden_states
:
Optional
[
Tuple
[
tf
.
Tensor
,
...]]
=
None
@dataclass
class TFMaskedImageModelingOutput(ModelOutput):
    """
    Base class for outputs of masked image completion / in-painting models.

    Args:
        loss (`tf.Tensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided):
            Reconstruction loss.
        reconstruction (`tf.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Reconstructed / completed images.
        hidden_states (`tuple(tf.Tensor)`, *optional*, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`):
            Tuple of `tf.Tensor` (one for the output of the embeddings, if the model has an embedding layer, + one for
            the output of each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called
            feature maps) of the model at the output of each stage.
        attentions (`tuple(tf.Tensor)`, *optional*, returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`):
            Tuple of `tf.Tensor` (one for each layer) of shape `(batch_size, num_heads, patch_size, sequence_length)`.
            Attention weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    loss: Optional[tf.Tensor] = None
    reconstruction: tf.Tensor = None
    # Variable-length tuples: one tensor per stage / layer, not a 1-tuple.
    hidden_states: Optional[Tuple[tf.Tensor, ...]] = None
    attentions: Optional[Tuple[tf.Tensor, ...]] = None

    @property
    def logits(self):
        """
        Deprecated alias for `reconstruction`.

        Emits a `FutureWarning` on every access and returns `self.reconstruction` unchanged.
        """
        warnings.warn(
            "logits attribute is deprecated and will be removed in version 5 of Transformers."
            " Please use the reconstruction attribute to retrieve the final output instead.",
            FutureWarning,
            # Attribute the warning to the code that read `.logits`, not to this property body.
            stacklevel=2,
        )
        return self.reconstruction
src/transformers/models/deit/modeling_deit.py
View file @
0558914d
...
@@ -26,7 +26,12 @@ from torch import nn
...
@@ -26,7 +26,12 @@ from torch import nn
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
,
ImageClassifierOutput
,
MaskedLMOutput
from
...modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutputWithPooling
,
ImageClassifierOutput
,
MaskedImageModelingOutput
,
)
from
...modeling_utils
import
PreTrainedModel
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...pytorch_utils
import
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...utils
import
(
from
...utils
import
(
...
@@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
...
@@ -592,7 +597,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
self
.
post_init
()
self
.
post_init
()
@
add_start_docstrings_to_model_forward
(
DEIT_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
DEIT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
Masked
LM
Output
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
Masked
ImageModeling
Output
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
...
@@ -601,7 +606,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output_attentions
:
Optional
[
bool
]
=
None
,
output_attentions
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
tuple
,
Masked
LM
Output
]:
)
->
Union
[
tuple
,
Masked
ImageModeling
Output
]:
r
"""
r
"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...
@@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
...
@@ -627,7 +632,7 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
reconstruction
>>> list(reconstructed_pixel_values.shape)
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
[1, 3, 224, 224]
```"""
```"""
...
@@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
...
@@ -670,9 +675,9 @@ class DeiTForMaskedImageModeling(DeiTPreTrainedModel):
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
Masked
LM
Output
(
return
Masked
ImageModeling
Output
(
loss
=
masked_im_loss
,
loss
=
masked_im_loss
,
logits
=
reconstructed_pixel_values
,
reconstruction
=
reconstructed_pixel_values
,
hidden_states
=
outputs
.
hidden_states
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
)
)
...
...
src/transformers/models/deit/modeling_tf_deit.py
View file @
0558914d
...
@@ -27,7 +27,7 @@ from ...modeling_tf_outputs import (
...
@@ -27,7 +27,7 @@ from ...modeling_tf_outputs import (
TFBaseModelOutput
,
TFBaseModelOutput
,
TFBaseModelOutputWithPooling
,
TFBaseModelOutputWithPooling
,
TFImageClassifierOutput
,
TFImageClassifierOutput
,
TFMasked
LM
Output
,
TFMasked
ImageModeling
Output
,
)
)
from
...modeling_tf_utils
import
(
from
...modeling_tf_utils
import
(
TFPreTrainedModel
,
TFPreTrainedModel
,
...
@@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
...
@@ -769,7 +769,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
@
unpack_inputs
@
unpack_inputs
@
add_start_docstrings_to_model_forward
(
DEIT_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
DEIT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
TFMasked
LM
Output
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
TFMasked
ImageModeling
Output
,
config_class
=
_CONFIG_FOR_DOC
)
def
call
(
def
call
(
self
,
self
,
pixel_values
:
Optional
[
tf
.
Tensor
]
=
None
,
pixel_values
:
Optional
[
tf
.
Tensor
]
=
None
,
...
@@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
...
@@ -779,7 +779,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
training
:
bool
=
False
,
training
:
bool
=
False
,
)
->
Union
[
tuple
,
TFMasked
LM
Output
]:
)
->
Union
[
tuple
,
TFMasked
ImageModeling
Output
]:
r
"""
r
"""
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
bool_masked_pos (`tf.Tensor` of type bool and shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...
@@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
...
@@ -805,7 +805,7 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
>>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
>>> bool_masked_pos = tf.cast(tf.random.uniform((1, num_patches), minval=0, maxval=2, dtype=tf.int32), tf.bool)
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
reconstruction
>>> list(reconstructed_pixel_values.shape)
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
[1, 3, 224, 224]
```"""
```"""
...
@@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
...
@@ -860,18 +860,20 @@ class TFDeiTForMaskedImageModeling(TFDeiTPreTrainedModel):
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
TFMasked
LM
Output
(
return
TFMasked
ImageModeling
Output
(
loss
=
masked_im_loss
,
loss
=
masked_im_loss
,
logits
=
reconstructed_pixel_values
,
reconstruction
=
reconstructed_pixel_values
,
hidden_states
=
outputs
.
hidden_states
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
)
)
def
serving_output
(
self
,
output
:
TFMasked
LM
Output
)
->
TFMasked
LM
Output
:
def
serving_output
(
self
,
output
:
TFMasked
ImageModeling
Output
)
->
TFMasked
ImageModeling
Output
:
hidden_states
=
tf
.
convert_to_tensor
(
output
.
hidden_states
)
if
self
.
config
.
output_hidden_states
else
None
hidden_states
=
tf
.
convert_to_tensor
(
output
.
hidden_states
)
if
self
.
config
.
output_hidden_states
else
None
attentions
=
tf
.
convert_to_tensor
(
output
.
attentions
)
if
self
.
config
.
output_attentions
else
None
attentions
=
tf
.
convert_to_tensor
(
output
.
attentions
)
if
self
.
config
.
output_attentions
else
None
return
TFMaskedLMOutput
(
logits
=
output
.
logits
,
hidden_states
=
hidden_states
,
attentions
=
attentions
)
return
TFMaskedImageModelingOutput
(
reconstruction
=
output
.
reconstruction
,
hidden_states
=
hidden_states
,
attentions
=
attentions
)
@
add_start_docstrings
(
@
add_start_docstrings
(
...
...
src/transformers/models/vit/modeling_vit.py
View file @
0558914d
...
@@ -25,7 +25,12 @@ from torch import nn
...
@@ -25,7 +25,12 @@ from torch import nn
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
torch.nn
import
BCEWithLogitsLoss
,
CrossEntropyLoss
,
MSELoss
from
...activations
import
ACT2FN
from
...activations
import
ACT2FN
from
...modeling_outputs
import
BaseModelOutput
,
BaseModelOutputWithPooling
,
ImageClassifierOutput
,
MaskedLMOutput
from
...modeling_outputs
import
(
BaseModelOutput
,
BaseModelOutputWithPooling
,
ImageClassifierOutput
,
MaskedImageModelingOutput
,
)
from
...modeling_utils
import
PreTrainedModel
from
...modeling_utils
import
PreTrainedModel
from
...pytorch_utils
import
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...pytorch_utils
import
find_pruneable_heads_and_indices
,
prune_linear_layer
from
...utils
import
(
from
...utils
import
(
...
@@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
...
@@ -647,7 +652,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
self
.
post_init
()
self
.
post_init
()
@
add_start_docstrings_to_model_forward
(
VIT_INPUTS_DOCSTRING
)
@
add_start_docstrings_to_model_forward
(
VIT_INPUTS_DOCSTRING
)
@
replace_return_docstrings
(
output_type
=
Masked
LM
Output
,
config_class
=
_CONFIG_FOR_DOC
)
@
replace_return_docstrings
(
output_type
=
Masked
ImageModeling
Output
,
config_class
=
_CONFIG_FOR_DOC
)
def
forward
(
def
forward
(
self
,
self
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
pixel_values
:
Optional
[
torch
.
Tensor
]
=
None
,
...
@@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
...
@@ -657,7 +662,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output_hidden_states
:
Optional
[
bool
]
=
None
,
output_hidden_states
:
Optional
[
bool
]
=
None
,
interpolate_pos_encoding
:
Optional
[
bool
]
=
None
,
interpolate_pos_encoding
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
return_dict
:
Optional
[
bool
]
=
None
,
)
->
Union
[
tuple
,
Masked
LM
Output
]:
)
->
Union
[
tuple
,
Masked
ImageModeling
Output
]:
r
"""
r
"""
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`):
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
...
@@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
...
@@ -683,7 +688,7 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> bool_masked_pos = torch.randint(low=0, high=2, size=(1, num_patches)).bool()
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> outputs = model(pixel_values, bool_masked_pos=bool_masked_pos)
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
logits
>>> loss, reconstructed_pixel_values = outputs.loss, outputs.
reconstruction
>>> list(reconstructed_pixel_values.shape)
>>> list(reconstructed_pixel_values.shape)
[1, 3, 224, 224]
[1, 3, 224, 224]
```"""
```"""
...
@@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
...
@@ -727,9 +732,9 @@ class ViTForMaskedImageModeling(ViTPreTrainedModel):
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
output
=
(
reconstructed_pixel_values
,)
+
outputs
[
1
:]
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
((
masked_im_loss
,)
+
output
)
if
masked_im_loss
is
not
None
else
output
return
Masked
LM
Output
(
return
Masked
ImageModeling
Output
(
loss
=
masked_im_loss
,
loss
=
masked_im_loss
,
logits
=
reconstructed_pixel_values
,
reconstruction
=
reconstructed_pixel_values
,
hidden_states
=
outputs
.
hidden_states
,
hidden_states
=
outputs
.
hidden_states
,
attentions
=
outputs
.
attentions
,
attentions
=
outputs
.
attentions
,
)
)
...
...
tests/models/deit/test_modeling_deit.py
View file @
0558914d
...
@@ -145,7 +145,7 @@ class DeiTModelTester:
...
@@ -145,7 +145,7 @@ class DeiTModelTester:
model
.
eval
()
model
.
eval
()
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
)
)
# test greyscale images
# test greyscale images
...
@@ -156,7 +156,7 @@ class DeiTModelTester:
...
@@ -156,7 +156,7 @@ class DeiTModelTester:
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
self
.
parent
.
assertEqual
(
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
config
.
num_labels
=
self
.
type_sequence_label_size
config
.
num_labels
=
self
.
type_sequence_label_size
...
...
tests/models/deit/test_modeling_tf_deit.py
View file @
0558914d
...
@@ -130,7 +130,7 @@ class TFDeiTModelTester:
...
@@ -130,7 +130,7 @@ class TFDeiTModelTester:
model
=
TFDeiTForMaskedImageModeling
(
config
=
config
)
model
=
TFDeiTForMaskedImageModeling
(
config
=
config
)
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
)
)
# test greyscale images
# test greyscale images
...
@@ -139,7 +139,7 @@ class TFDeiTModelTester:
...
@@ -139,7 +139,7 @@ class TFDeiTModelTester:
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
self
.
parent
.
assertEqual
(
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
config
.
num_labels
=
self
.
type_sequence_label_size
config
.
num_labels
=
self
.
type_sequence_label_size
...
...
tests/models/vit/test_modeling_vit.py
View file @
0558914d
...
@@ -134,7 +134,7 @@ class ViTModelTester:
...
@@ -134,7 +134,7 @@ class ViTModelTester:
model
.
eval
()
model
.
eval
()
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
)
)
)
# test greyscale images
# test greyscale images
...
@@ -145,7 +145,7 @@ class ViTModelTester:
...
@@ -145,7 +145,7 @@ class ViTModelTester:
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
pixel_values
=
floats_tensor
([
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
])
result
=
model
(
pixel_values
)
result
=
model
(
pixel_values
)
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
self
.
parent
.
assertEqual
(
result
.
reconstruction
.
shape
,
(
self
.
batch_size
,
1
,
self
.
image_size
,
self
.
image_size
))
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
def
create_and_check_for_image_classification
(
self
,
config
,
pixel_values
,
labels
):
config
.
num_labels
=
self
.
type_sequence_label_size
config
.
num_labels
=
self
.
type_sequence_label_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment