Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
25156eb2
Unverified
Commit
25156eb2
authored
Nov 29, 2021
by
NielsRogge
Committed by
GitHub
Nov 29, 2021
Browse files
Rename ImageGPT (#14526)
* Rename * Add MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
parent
4ee0b755
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
36 additions
and
19 deletions
+36
-19
docs/source/model_doc/imagegpt.rst
docs/source/model_doc/imagegpt.rst
+2
-2
src/transformers/__init__.py
src/transformers/__init__.py
+4
-2
src/transformers/models/auto/__init__.py
src/transformers/models/auto/__init__.py
+2
-0
src/transformers/models/auto/modeling_auto.py
src/transformers/models/auto/modeling_auto.py
+10
-2
src/transformers/models/imagegpt/__init__.py
src/transformers/models/imagegpt/__init__.py
+2
-2
src/transformers/models/imagegpt/modeling_imagegpt.py
src/transformers/models/imagegpt/modeling_imagegpt.py
+3
-3
src/transformers/utils/dummy_pt_objects.py
src/transformers/utils/dummy_pt_objects.py
+4
-1
tests/test_modeling_common.py
tests/test_modeling_common.py
+2
-0
tests/test_modeling_imagegpt.py
tests/test_modeling_imagegpt.py
+7
-7
No files found.
docs/source/model_doc/imagegpt.rst
View file @
25156eb2
...
@@ -96,10 +96,10 @@ ImageGPTModel
...
@@ -96,10 +96,10 @@ ImageGPTModel
:
members
:
forward
:
members
:
forward
ImageGPTForCausal
LM
ImageGPTForCausal
ImageModeling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
..
autoclass
::
transformers
.
ImageGPTForCausal
LM
..
autoclass
::
transformers
.
ImageGPTForCausal
ImageModeling
:
members
:
forward
:
members
:
forward
...
...
src/transformers/__init__.py
View file @
25156eb2
...
@@ -619,6 +619,7 @@ if is_torch_available():
...
@@ -619,6 +619,7 @@ if is_torch_available():
_import_structure
[
"models.auto"
].
extend
(
_import_structure
[
"models.auto"
].
extend
(
[
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
...
@@ -977,7 +978,7 @@ if is_torch_available():
...
@@ -977,7 +978,7 @@ if is_torch_available():
_import_structure
[
"models.imagegpt"
].
extend
(
_import_structure
[
"models.imagegpt"
].
extend
(
[
[
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"ImageGPTForCausal
LM
"
,
"ImageGPTForCausal
ImageModeling
"
,
"ImageGPTForImageClassification"
,
"ImageGPTForImageClassification"
,
"ImageGPTModel"
,
"ImageGPTModel"
,
"ImageGPTPreTrainedModel"
,
"ImageGPTPreTrainedModel"
,
...
@@ -2521,6 +2522,7 @@ if TYPE_CHECKING:
...
@@ -2521,6 +2522,7 @@ if TYPE_CHECKING:
)
)
from
.models.auto
import
(
from
.models.auto
import
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
...
@@ -2823,7 +2825,7 @@ if TYPE_CHECKING:
...
@@ -2823,7 +2825,7 @@ if TYPE_CHECKING:
)
)
from
.models.imagegpt
import
(
from
.models.imagegpt
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTForImageClassification
,
ImageGPTModel
,
ImageGPTModel
,
ImageGPTPreTrainedModel
,
ImageGPTPreTrainedModel
,
...
...
src/transformers/models/auto/__init__.py
View file @
25156eb2
...
@@ -32,6 +32,7 @@ _import_structure = {
...
@@ -32,6 +32,7 @@ _import_structure = {
if
is_torch_available
():
if
is_torch_available
():
_import_structure
[
"modeling_auto"
]
=
[
_import_structure
[
"modeling_auto"
]
=
[
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CAUSAL_LM_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_CTC_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
"MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING"
,
...
@@ -137,6 +138,7 @@ if TYPE_CHECKING:
...
@@ -137,6 +138,7 @@ if TYPE_CHECKING:
if
is_torch_available
():
if
is_torch_available
():
from
.modeling_auto
import
(
from
.modeling_auto
import
(
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_CTC_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
...
...
src/transformers/models/auto/modeling_auto.py
View file @
25156eb2
...
@@ -147,7 +147,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
...
@@ -147,7 +147,6 @@ MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
MODEL_WITH_LM_HEAD_MAPPING_NAMES
=
OrderedDict
(
MODEL_WITH_LM_HEAD_MAPPING_NAMES
=
OrderedDict
(
[
[
# Model with LM heads mapping
# Model with LM heads mapping
(
"imagegpt"
,
"ImageGPTForCausalLM"
),
(
"qdqbert"
,
"QDQBertForMaskedLM"
),
(
"qdqbert"
,
"QDQBertForMaskedLM"
),
(
"fnet"
,
"FNetForMaskedLM"
),
(
"fnet"
,
"FNetForMaskedLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
...
@@ -199,7 +198,6 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
...
@@ -199,7 +198,6 @@ MODEL_WITH_LM_HEAD_MAPPING_NAMES = OrderedDict(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
=
OrderedDict
(
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
=
OrderedDict
(
[
[
# Model for Causal LM mapping
# Model for Causal LM mapping
(
"imagegpt"
,
"ImageGPTForCausalLM"
),
(
"qdqbert"
,
"QDQBertLMHeadModel"
),
(
"qdqbert"
,
"QDQBertLMHeadModel"
),
(
"trocr"
,
"TrOCRForCausalLM"
),
(
"trocr"
,
"TrOCRForCausalLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
(
"gptj"
,
"GPTJForCausalLM"
),
...
@@ -233,6 +231,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
...
@@ -233,6 +231,13 @@ MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict(
]
]
)
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
=
OrderedDict
(
# Model for Causal Image Modeling mapping
[
(
"imagegpt"
,
"ImageGPTForCausalImageModeling"
),
]
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
=
OrderedDict
(
[
[
# Model for Image Classification mapping
# Model for Image Classification mapping
...
@@ -524,6 +529,9 @@ MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
...
@@ -524,6 +529,9 @@ MODEL_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, MODEL_MAPPING_NAMES)
MODEL_FOR_PRETRAINING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_PRETRAINING_MAPPING_NAMES
)
MODEL_FOR_PRETRAINING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_PRETRAINING_MAPPING_NAMES
)
MODEL_WITH_LM_HEAD_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_WITH_LM_HEAD_MAPPING_NAMES
)
MODEL_WITH_LM_HEAD_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_WITH_LM_HEAD_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_LM_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
)
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING_NAMES
)
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
=
_LazyAutoMapping
(
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
CONFIG_MAPPING_NAMES
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES
)
)
...
...
src/transformers/models/imagegpt/__init__.py
View file @
25156eb2
...
@@ -31,7 +31,7 @@ if is_vision_available():
...
@@ -31,7 +31,7 @@ if is_vision_available():
if
is_torch_available
():
if
is_torch_available
():
_import_structure
[
"modeling_imagegpt"
]
=
[
_import_structure
[
"modeling_imagegpt"
]
=
[
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST"
,
"ImageGPTForCausal
LM
"
,
"ImageGPTForCausal
ImageModeling
"
,
"ImageGPTForImageClassification"
,
"ImageGPTForImageClassification"
,
"ImageGPTModel"
,
"ImageGPTModel"
,
"ImageGPTPreTrainedModel"
,
"ImageGPTPreTrainedModel"
,
...
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
...
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
if
is_torch_available
():
if
is_torch_available
():
from
.modeling_imagegpt
import
(
from
.modeling_imagegpt
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTForImageClassification
,
ImageGPTModel
,
ImageGPTModel
,
ImageGPTPreTrainedModel
,
ImageGPTPreTrainedModel
,
...
...
src/transformers/models/imagegpt/modeling_imagegpt.py
View file @
25156eb2
...
@@ -881,7 +881,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
...
@@ -881,7 +881,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
"""
,
"""
,
IMAGEGPT_START_DOCSTRING
,
IMAGEGPT_START_DOCSTRING
,
)
)
class
ImageGPTForCausal
LM
(
ImageGPTPreTrainedModel
):
class
ImageGPTForCausal
ImageModeling
(
ImageGPTPreTrainedModel
):
_keys_to_ignore_on_load_missing
=
[
r
"attn.masked_bias"
,
r
"attn.bias"
,
r
"lm_head.weight"
]
_keys_to_ignore_on_load_missing
=
[
r
"attn.masked_bias"
,
r
"attn.bias"
,
r
"lm_head.weight"
]
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
...
@@ -958,13 +958,13 @@ class ImageGPTForCausalLM(ImageGPTPreTrainedModel):
...
@@ -958,13 +958,13 @@ class ImageGPTForCausalLM(ImageGPTPreTrainedModel):
Examples::
Examples::
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausal
LM
>>> from transformers import ImageGPTFeatureExtractor, ImageGPTForCausal
ImageModeling
>>> import torch
>>> import torch
>>> import matplotlib.pyplot as plt
>>> import matplotlib.pyplot as plt
>>> import numpy as np
>>> import numpy as np
>>> feature_extractor = ImageGPTFeatureExtractor.from_pretrained('openai/imagegpt-small')
>>> feature_extractor = ImageGPTFeatureExtractor.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausal
LM
.from_pretrained('openai/imagegpt-small')
>>> model = ImageGPTForCausal
ImageModeling
.from_pretrained('openai/imagegpt-small')
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
>>> model.to(device)
>>> model.to(device)
...
...
src/transformers/utils/dummy_pt_objects.py
View file @
25156eb2
...
@@ -341,6 +341,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
...
@@ -341,6 +341,9 @@ def load_tf_weights_in_albert(*args, **kwargs):
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
=
None
MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
=
None
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
=
None
MODEL_FOR_CAUSAL_LM_MAPPING
=
None
MODEL_FOR_CAUSAL_LM_MAPPING
=
None
...
@@ -2661,7 +2664,7 @@ class IBertPreTrainedModel:
...
@@ -2661,7 +2664,7 @@ class IBertPreTrainedModel:
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
=
None
class
ImageGPTForCausal
LM
:
class
ImageGPTForCausal
ImageModeling
:
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
requires_backends
(
self
,
[
"torch"
])
requires_backends
(
self
,
[
"torch"
])
...
...
tests/test_modeling_common.py
View file @
25156eb2
...
@@ -61,6 +61,7 @@ if is_torch_available():
...
@@ -61,6 +61,7 @@ if is_torch_available():
from
transformers
import
(
from
transformers
import
(
BERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
BERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_CAUSAL_LM_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
,
MODEL_FOR_MASKED_LM_MAPPING
,
MODEL_FOR_MASKED_LM_MAPPING
,
...
@@ -150,6 +151,7 @@ class ModelTesterMixin:
...
@@ -150,6 +151,7 @@ class ModelTesterMixin:
elif
model_class
in
[
elif
model_class
in
[
*
get_values
(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
),
*
get_values
(
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
),
*
get_values
(
MODEL_FOR_CAUSAL_LM_MAPPING
),
*
get_values
(
MODEL_FOR_CAUSAL_LM_MAPPING
),
*
get_values
(
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING
),
*
get_values
(
MODEL_FOR_MASKED_LM_MAPPING
),
*
get_values
(
MODEL_FOR_MASKED_LM_MAPPING
),
*
get_values
(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
),
*
get_values
(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
),
]:
]:
...
...
tests/test_modeling_imagegpt.py
View file @
25156eb2
...
@@ -34,7 +34,7 @@ if is_torch_available():
...
@@ -34,7 +34,7 @@ if is_torch_available():
from
transformers
import
(
from
transformers
import
(
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
IMAGEGPT_PRETRAINED_MODEL_ARCHIVE_LIST
,
ImageGPTForCausal
LM
,
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTForImageClassification
,
ImageGPTModel
,
ImageGPTModel
,
)
)
...
@@ -207,14 +207,14 @@ class ImageGPTModelTester:
...
@@ -207,14 +207,14 @@ class ImageGPTModelTester:
self
.
parent
.
assertEqual
(
len
(
result
.
past_key_values
),
config
.
n_layer
)
self
.
parent
.
assertEqual
(
len
(
result
.
past_key_values
),
config
.
n_layer
)
def
create_and_check_lm_head_model
(
self
,
config
,
pixel_values
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_lm_head_model
(
self
,
config
,
pixel_values
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
ImageGPTForCausal
LM
(
config
)
model
=
ImageGPTForCausal
ImageModeling
(
config
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
model
.
eval
()
labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
-
1
)
labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
-
1
)
result
=
model
(
pixel_values
,
token_type_ids
=
token_type_ids
,
labels
=
labels
)
result
=
model
(
pixel_values
,
token_type_ids
=
token_type_ids
,
labels
=
labels
)
self
.
parent
.
assertEqual
(
result
.
loss
.
shape
,
())
self
.
parent
.
assertEqual
(
result
.
loss
.
shape
,
())
# ImageGPTForCausal
LM
doens't have tied input- and output embeddings
# ImageGPTForCausal
ImageModeling
doens't have tied input- and output embeddings
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
-
1
))
self
.
parent
.
assertEqual
(
result
.
logits
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
-
1
))
def
create_and_check_imagegpt_for_image_classification
(
def
create_and_check_imagegpt_for_image_classification
(
...
@@ -255,9 +255,9 @@ class ImageGPTModelTester:
...
@@ -255,9 +255,9 @@ class ImageGPTModelTester:
class
ImageGPTModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
unittest
.
TestCase
):
class
ImageGPTModelTest
(
ModelTesterMixin
,
GenerationTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
all_model_classes
=
(
(
ImageGPTForCausal
LM
,
ImageGPTForImageClassification
,
ImageGPTModel
)
if
is_torch_available
()
else
()
(
ImageGPTForCausal
ImageModeling
,
ImageGPTForImageClassification
,
ImageGPTModel
)
if
is_torch_available
()
else
()
)
)
all_generative_model_classes
=
(
ImageGPTForCausal
LM
,)
if
is_torch_available
()
else
()
all_generative_model_classes
=
(
ImageGPTForCausal
ImageModeling
,)
if
is_torch_available
()
else
()
test_missing_keys
=
False
test_missing_keys
=
False
input_name
=
"pixel_values"
input_name
=
"pixel_values"
...
@@ -273,7 +273,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
...
@@ -273,7 +273,7 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCa
return
inputs_dict
return
inputs_dict
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausal
LM
doesn't have tied input- and output embeddings
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausal
ImageModeling
doesn't have tied input- and output embeddings
def
_check_scores
(
self
,
batch_size
,
scores
,
length
,
config
):
def
_check_scores
(
self
,
batch_size
,
scores
,
length
,
config
):
expected_shape
=
(
batch_size
,
config
.
vocab_size
-
1
)
expected_shape
=
(
batch_size
,
config
.
vocab_size
-
1
)
self
.
assertIsInstance
(
scores
,
tuple
)
self
.
assertIsInstance
(
scores
,
tuple
)
...
@@ -519,7 +519,7 @@ class ImageGPTModelIntegrationTest(unittest.TestCase):
...
@@ -519,7 +519,7 @@ class ImageGPTModelIntegrationTest(unittest.TestCase):
@
slow
@
slow
def
test_inference_causal_lm_head
(
self
):
def
test_inference_causal_lm_head
(
self
):
model
=
ImageGPTForCausal
LM
.
from_pretrained
(
"openai/imagegpt-small"
).
to
(
torch_device
)
model
=
ImageGPTForCausal
ImageModeling
.
from_pretrained
(
"openai/imagegpt-small"
).
to
(
torch_device
)
feature_extractor
=
self
.
default_feature_extractor
feature_extractor
=
self
.
default_feature_extractor
image
=
prepare_img
()
image
=
prepare_img
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment