chenpangpang / transformers · Commit 84162068
Unverified commit 84162068, authored Mar 12, 2022 by João Gustavo A. Amorim; committed by GitHub on Mar 12, 2022.
apply unpack_input decorator to ViT model (#16102)
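The refactor replaces the hand-rolled `inputs = input_processing(...)` call and the `inputs["..."]` dict lookups in each `call` method with the `@unpack_inputs` decorator from `...modeling_tf_utils`, so the method bodies can read plain parameter names (`pixel_values`, `head_mask`, ...) directly. A minimal sketch of the pattern, assuming a simplified decorator (this is an illustration only; `unpack_inputs_sketch` and `TinyLayer` are hypothetical names, not the library's actual implementation):

```python
import functools
import inspect

def unpack_inputs_sketch(func):
    """Illustrative stand-in for transformers' `unpack_inputs`: map whatever
    the caller passed positionally onto `func`'s declared parameter names and
    invoke the wrapped method with keyword arguments only."""
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        params = list(inspect.signature(func).parameters)[1:]  # skip `self`
        kwargs.update(zip(params, args))  # positional -> keyword
        return func(self, **kwargs)
    return wrapper

class TinyLayer:
    @unpack_inputs_sketch
    def call(self, pixel_values=None, training=False):
        # the body sees plain names instead of an `inputs` dict
        return pixel_values, training

print(TinyLayer().call("px", training=True))  # ('px', True)
print(TinyLayer().call(training=True))        # (None, True)
```

With the decorator in place, `pixel_values` no longer has to be smuggled through `input_ids` and popped back out by hand, which is where most of the 67 deleted lines below come from.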
Parent: 62b05b69
Showing 1 changed file with 24 additions and 67 deletions.

src/transformers/models/vit/modeling_tf_vit.py (+24, -67)
@@ -30,8 +30,8 @@ from ...modeling_tf_utils import (
     TFPreTrainedModel,
     TFSequenceClassificationLoss,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -477,6 +477,7 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         """
         raise NotImplementedError

+    @unpack_inputs
     def call(
         self,
         pixel_values: Optional[TFModelInputType] = None,
@@ -488,29 +489,14 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         training: bool = False,
         **kwargs,
     ) -> Union[TFBaseModelOutputWithPooling, Tuple[tf.Tensor]]:
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        if inputs["pixel_values"] is None:
+        if pixel_values is None:
             raise ValueError("You have to specify pixel_values")

         embedding_output = self.embeddings(
-            pixel_values=inputs["pixel_values"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            training=inputs["training"],
+            pixel_values=pixel_values,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            training=training,
         )

         # Prepare head mask if needed
@@ -518,25 +504,25 @@ class TFViTMainLayer(tf.keras.layers.Layer):
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.config.num_hidden_layers
+            head_mask = [None] * self.config.num_hidden_layers

         encoder_outputs = self.encoder(
             hidden_states=embedding_output,
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )

         sequence_output = encoder_outputs[0]
         sequence_output = self.layernorm(inputs=sequence_output)
         pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None

-        if not inputs["return_dict"]:
+        if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]

         return TFBaseModelOutputWithPooling(
@@ -659,6 +645,7 @@ class TFViTModel(TFViTPreTrainedModel):
         self.vit = TFViTMainLayer(config, add_pooling_layer=add_pooling_layer, name="vit")

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFBaseModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -692,30 +679,15 @@ class TFViTModel(TFViTPreTrainedModel):
         >>> outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+        outputs = self.vit(
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
         )

         return outputs
@@ -773,6 +745,7 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
             name="classifier",
         )

+    @unpack_inputs
     @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=TFSequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def call(
@@ -816,37 +789,21 @@ class TFViTForImageClassification(TFViTPreTrainedModel, TFSequenceClassificationLoss):
         >>> predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
         >>> print("Predicted class:", model.config.id2label[int(predicted_class_idx)])
         ```"""
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=pixel_values,
-            head_mask=head_mask,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            interpolate_pos_encoding=interpolate_pos_encoding,
-            return_dict=return_dict,
-            labels=labels,
-            training=training,
-            kwargs_call=kwargs,
-        )
-
-        if "input_ids" in inputs:
-            inputs["pixel_values"] = inputs.pop("input_ids")
-
-        outputs = self.vit(
-            pixel_values=inputs["pixel_values"],
-            head_mask=inputs["head_mask"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            interpolate_pos_encoding=inputs["interpolate_pos_encoding"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+        outputs = self.vit(
+            pixel_values=pixel_values,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            interpolate_pos_encoding=interpolate_pos_encoding,
+            return_dict=return_dict,
+            training=training,
        )

         sequence_output = outputs[0]
         logits = self.classifier(inputs=sequence_output[:, 0, :])
-        loss = None if inputs["labels"] is None else self.hf_compute_loss(labels=inputs["labels"], logits=logits)
+        loss = None if labels is None else self.hf_compute_loss(labels=labels, logits=logits)

-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
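One detail worth flagging in the hunks above: `@unpack_inputs` is listed above the docstring decorators, meaning it is applied last and so runs first at call time, while `add_start_docstrings_to_model_forward` and `replace_return_docstrings` only rewrite `__doc__` at definition time. A minimal sketch of why that ordering is safe (hypothetical names, not transformers code):

```python
import functools

def unpack_like(f):
    """Behavior-wrapping decorator (stands in for `@unpack_inputs`)."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        print("normalize arguments first")
        return f(*args, **kwargs)
    return wrapper

def doc_like(f):
    """Definition-time decorator (stands in for the docstring decorators):
    mutates `__doc__` and returns the same function object."""
    f.__doc__ = "patched docstring"
    return f

@unpack_like   # listed first -> applied last -> outermost wrapper at call time
@doc_like
def call():
    """original docstring"""
    return "body"

print(call())        # prints "normalize arguments first", then "body"
print(call.__doc__)  # "patched docstring" (copied up by functools.wraps)
```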