Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
25156eb2
Unverified
Commit
25156eb2
authored
Nov 29, 2021
by
NielsRogge
Committed by
GitHub
Nov 29, 2021
Browse files
Rename ImageGPT (#14526)
* Rename * Add MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
parent
4ee0b755
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
36 additions
and
19 deletions
+36
-19
docs/source/model_doc/imagegpt.rst
docs/source/model_doc/imagegpt.rst
+2
-2
src/transformers/__init__.py
src/transformers/__init__.py
+4
-2
src/transformers/models/auto/__init__.py
src/transformers/models/auto/__init__.py
+2
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+10
-2
src/transformers/models/imagegpt/__init__.py
src/transformers/models/imagegpt/__init__.py
+2
-2
src/transformers/models/imagegpt/modeling_imagegpt.py
src/transformers/models/imagegpt/modeling_imagegpt.py
+3
-3
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+4
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+2
-0
tests/test_modeling_imagegpt.py
tests/test_modeling_imagegpt.py
+7
-7
No files found.
docs/source/model_doc/imagegpt.rst
View file @
25156eb2
...
...
@@ -96,10 +96,10 @@ ImageGPTModel
:
members
:
forward
ImageGPTForCausal
LM
ImageGPTForCausal
ImageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
..
autoclass
::
transformers
.
ImageGPTForCausal
LM
..
autoclass
::
transformers
.
ImageGPTForCausal
ImageModeling
:
members
:
forward
...
...
src/transformers/__init__.py
View file @
25156eb2
...
...
@@ -619,6 +619,7 @@ if is_torch_available():
_import_structure
[
"models.auto"
].
extend
(
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
...
...
@@ -977,7 +978,7 @@ if is_torch_available():
_import_structure
[
"models.imagegpt"
].
extend
(
[
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"ImageGPTForCausal
LM
"
,
"ImageGPTForCausal
ImageModeling
"
,
"ImageGPTForImageClassification"
,
"ImageGPTModel"
,
"ImageGPTPreTrainedModel"
,
...
...
@@ -2521,6 +2522,7 @@ if TYPE_CHECKING:
)
from
.models.auto
import
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
...
...
@@ -2823,7 +2825,7 @@ if TYPE_CHECKING:
)
from
.models.imagegpt
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTModel
,
ImageGPTPreTrainedModel
,
...
...
src/transformers/models/auto/__init__.py
View file @
25156eb2
...
...
@@ -32,6 +32,7 @@ _import_structure = {
if
is_torch_available
():
_import_structure
[
"modeling_auto"
]
=
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
...
...
@@ -137,6 +138,7 @@ if TYPE_CHECKING:
if
is_torch_available
():
from
.modeling_auto
import
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
...
...
src/transformers/models/auto/modeling_auto.py
View file @
25156eb2
...
...
@@ -147,7 +147,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING_NAMES
=
OrderedDict
(
[
# Model with LM heads mapping
(
"imagegpt"
,
"ImageGPTForCausalLM"
),
(
"qdqbert"
,
"QDQBertForMaskedLM"
),
(
"fnet"
,
"FNetForMaskedLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
...
...
@@ -199,7 +198,6 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Causal LM mapping
(
"imagegpt"
,
"ImageGPTForCausalLM"
),
(
"qdqbert"
,
"QDQBertLMHeadModel"
),
(
"trocr"
,
"TrOCRForCausalLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
...
...
@@ -233,6 +231,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
=
OrderedDict
(
# Model for Causal Image Modeling mapping
[
(
"imagegpt"
,
"ImageGPTForCausalImageModeling"
),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
# Model for Image Classification mapping
...
...
@@ -524,6 +529,9 @@ MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_PRETRAINING_MAPPING_NAMES
)
MODEL_WITH_LM_HEAD_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_WITH_LM_HEAD_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
...
...
src/transformers/models/imagegpt/__init__.py
View file @
25156eb2
...
...
@@ -31,7 +31,7 @@ if is_vision_available():
if
is_torch_available
():
_import_structure
[
"modeling_imagegpt"
]
=
[
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"ImageGPTForCausal
LM
"
,
"ImageGPTForCausal
ImageModeling
"
,
"ImageGPTForImageClassification"
,
"ImageGPTModel"
,
"ImageGPTPreTrainedModel"
,
...
...
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
if
is_torch_available
():
from
.modeling_imagegpt
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTModel
,
ImageGPTPreTrainedModel
,
...
...
src/transformers/models/imagegpt/modeling_imagegpt.py
View file @
25156eb2
...
...
@@ -881,7 +881,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
"""
,
IMAGEGPT_START_DOCSTRING
,
)
class
ImageGPTForCausal
LM
(
ImageGPTPreTrainedModel
):
class
ImageGPTForCausal
ImageModeling
(
ImageGPTPreTrainedModel
):
_keys_to_ignore_on_load_missing
=
[
r
"attn.masked_bias"
,
r
"attn.bias"
,
r
"lm_head.weight"
]
def
__init__
(
self
,
config
):
...
...
@@ -958,13 +958,13 @@ class ImageGPTForCausalLM(ImageGPTPreTrainedModel):
Examples::
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausal
LM
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausal
ImageModeling
>>> import torch
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> feature_extractor = ImageGPTFeatureExtractor.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausal
LM
.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausal
ImageModeling
.from_pretrained('openai/imagegpt-small')
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model.to(device)
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
25156eb2
...
...
@@ -341,6 +341,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
=
None
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
=
None
MODEL_FOR_CAUSAL_LM_MAPPING
=
None
...
...
@@ -2661,7 +2664,7 @@ class IBertPreTrainedModel:
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
class
ImageGPTForCausal
LM
:
class
ImageGPTForCausal
ImageModeling
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
...
...
tests/test_modeling_common.py
View file @
25156eb2
...
...
@@ -61,6 +61,7 @@ if is_torch_available():
from
transformers
import
(
BERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_MASKED_LM_MAPPING
,
...
...
@@ -150,6 +151,7 @@ class ModelTesterMixin:
elif
model_class
in
[
*
get_values
(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
),
*
get_values
(
MODEL_FOR_CAUSAL_LM_MAPPING
),
*
get_values
(
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
),
*
get_values
(
MODEL_FOR_MASKED_LM_MAPPING
),
*
get_values
(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
),
]:
...
...
tests/test_modeling_imagegpt.py
View file @
25156eb2
...
...
@@ -34,7 +34,7 @@ if is_torch_available():
from
transformers
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTModel
,
)
...
...
@@ -207,14 +207,14 @@ class ImageGPTModelTester:
self
.
parent
.
assertEqual
(
len
(
result
.
past_key_values
),
config
.
n_layer
)
def
create_and_check_lm_head_model
(
self
,
config
,
pixel_values
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
ImageGPTForCausal
LM
(
config
)
model
=
ImageGPTForCausal
ImageModeling
(
config
)
model
.
to
(
torch_device
)
model
.
eval
()
labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
-
1
)
result
=
model
(
pixel_values
,
token_type_ids
=
token_type_ids
,
labels
=
labels
)
self
.
parent
.
assertEqual
(
result
.
loss
.
shape
,
())
# ImageGPTForCausal
LM
doens't have tied input- and output embeddings
# ImageGPTForCausal
ImageModeling
doens't have tied input- and output embeddings
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
-
1
))
def
create_and_check_imagegpt_for_image_classification
(
...
...
@@ -255,9 +255,9 @@ class ImageGPTModelTester:
class
ImageGPTModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
ImageGPTForCausal
LM
,
ImageGPTForImageClassification
,
ImageGPTModel
)
if
is_torch_available
()
else
()
(
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTModel
)
if
is_torch_available
()
else
()
)
all_generative_model_classes
=
(
ImageGPTForCausal
LM
,)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
ImageGPTForCausal
ImageModeling
,)
if
is_torch_available
()
else
()
test_missing_keys
=
False
input_name
=
"pixel_values"
...
...
@@ -273,7 +273,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
return
inputs_dict
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausal
LM
doesn't have tied input- and output embeddings
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausal
ImageModeling
doesn't have tied input- and output embeddings
def
_check_scores
(
self
,
batch_size
,
scores
,
length
,
config
):
expected_shape
=
(
batch_size
,
config
.
vocab_size
-
1
)
self
.
assertIsInstance
(
scores
,
tuple
)
...
...
@@ -519,7 +519,7 @@ class ImageGPTModelIntegrationTest(unittest.TestCase):
@
slow
def
test_inference_causal_lm_head
(
self
):
model
=
ImageGPTForCausal
LM
.
from_pretrained
(
"openai/imagegpt-small"
).
to
(
torch_device
)
model
=
ImageGPTForCausal
ImageModeling
.
from_pretrained
(
"openai/imagegpt-small"
).
to
(
torch_device
)
feature_extractor
=
self
.
default_feature_extractor
image
=
prepare_img
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment