chenpangpang / transformers / Commits

Unverified commit 7f3d4440, authored Mar 11, 2022 by João Gustavo A. Amorim, committed by GitHub on Mar 11, 2022

add type annotations for ImageGPT (#16088)
parent 5b4c97d0

Showing 1 changed file with 72 additions and 72 deletions:
src/transformers/models/imagegpt/modeling_imagegpt.py (+72 -72)
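The change is annotation-only: no runtime behaviour is modified, the signatures now simply document the expected types. As a rough illustration of what the annotated entry points communicate to a caller, here is a minimal sketch using a toy configuration (the config values and tensor shapes below are illustrative assumptions, not part of this commit):

import torch
from transformers import ImageGPTConfig, ImageGPTModel

# Toy config purely for illustration; real ImageGPT checkpoints use much larger values.
config = ImageGPTConfig(vocab_size=16, n_positions=16, n_embd=32, n_layer=2, n_head=2)

# __init__ is now annotated as (self, config: ImageGPTConfig)
model = ImageGPTModel(config)

# forward() now declares input_ids: Optional[torch.Tensor] = None and returns
# Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]
input_ids = torch.randint(0, config.vocab_size - 1, (1, 16))
outputs = model(input_ids=input_ids, return_dict=True)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 16, 32]) for this toy config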
src/transformers/models/imagegpt/modeling_imagegpt.py @ 7f3d4440
@@ -17,7 +17,7 @@
 import math
 import os
 import warnings
-from typing import Tuple
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -167,12 +167,12 @@ def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):

 class ImageGPTLayerNorm(nn.Module):
-    def __init__(self, hidden_size, eps=1e-5):
+    def __init__(self, hidden_size: Tuple[int], eps: float = 1e-5):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.Tensor(hidden_size))

-    def forward(self, tensor):
+    def forward(self, tensor: torch.Tensor) -> tuple:
         # input is not mean centered
         return (
             tensor
@@ -182,7 +182,7 @@ class ImageGPTLayerNorm(nn.Module):

 class ImageGPTAttention(nn.Module):
-    def __init__(self, config, is_cross_attention=False, layer_idx=None):
+    def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
         super().__init__()

         max_positions = config.max_position_embeddings
@@ -343,15 +343,15 @@ class ImageGPTAttention(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        layer_past: Optional[bool] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple:
         if encoder_hidden_states is not None:
             if not hasattr(self, "q_attn"):
                 raise ValueError(
@@ -404,7 +404,7 @@ class ImageGPTMLP(nn.Module):
         self.act = ACT2FN[config.activation_function]
         self.dropout = nn.Dropout(config.resid_pdrop)

-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.c_fc(hidden_states)
         hidden_states = self.act(hidden_states)
         hidden_states = self.c_proj(hidden_states)
@@ -430,15 +430,15 @@ class ImageGPTBlock(nn.Module):
     def forward(
         self,
-        hidden_states,
-        layer_past=None,
-        attention_mask=None,
-        head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=False,
-        output_attentions=False,
-    ):
+        hidden_states: torch.Tensor,
+        layer_past: Optional[bool] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+    ) -> tuple:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_outputs = self.attn(
@@ -620,7 +620,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
 class ImageGPTModel(ImageGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

-    def __init__(self, config):
+    def __init__(self, config: ImageGPTConfig):
         super().__init__(config)
         self.embed_dim = config.hidden_size
@@ -656,21 +656,21 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -900,7 +900,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
 class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

-    def __init__(self, config):
+    def __init__(self, config: ImageGPTConfig):
         super().__init__(config)
         self.transformer = ImageGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
@@ -917,7 +917,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

-    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past: Optional[bool] = None, **kwargs):
         token_type_ids = kwargs.get("token_type_ids", None)
         # only last token for inputs_ids if past is defined in kwargs
         if past:
@@ -949,22 +949,22 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
@@ -1088,7 +1088,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
 class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

-    def __init__(self, config):
+    def __init__(self, config: ImageGPTConfig):
         super().__init__(config)
         self.num_labels = config.num_labels
         self.transformer = ImageGPTModel(config)
@@ -1101,20 +1101,20 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
     @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        past_key_values=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        **kwargs,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
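Because the whole effect of the commit is on annotations, one way to see the result without running the model is to read the hints back at runtime. A small sketch, assuming a transformers install that already includes this change:

from typing import get_type_hints

from transformers.models.imagegpt.modeling_imagegpt import ImageGPTModel

# The hints resolved here are exactly the annotations added in this diff.
hints = get_type_hints(ImageGPTModel.forward)
print(hints["input_ids"])  # roughly: Optional[torch.Tensor]
print(hints["return"])     # roughly: Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]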