"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "755144694369460914385326de9adb027f167236"
Unverified Commit bdfd57d1 authored by amyeroberts's avatar amyeroberts Committed by GitHub
Browse files

Fix ImageGPT doc example (#24317)

* Fix ImageGPT doc example

* Update src/transformers/models/imagegpt/image_processing_imagegpt.py

* Fix types
parent 096f2cf1
......@@ -60,9 +60,9 @@ class ImageGPTImageProcessor(BaseImageProcessor):
(color clusters).
Args:
clusters (`np.ndarray`, *optional*):
The color clusters to use, as a `np.ndarray` of shape `(n_clusters, 3)` when color quantizing. Can be
overriden by `clusters` in `preprocess`.
clusters (`np.ndarray` or `List[List[int]]`, *optional*):
The color clusters to use, of shape `(n_clusters, 3)` when color quantizing. Can be overriden by `clusters`
in `preprocess`.
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's dimensions to `(size["height"], size["width"])`. Can be overridden by
`do_resize` in `preprocess`.
......@@ -82,7 +82,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
def __init__(
self,
# clusters is a first argument to maintain backwards compatibility with the old ImageGPTFeatureExtractor
clusters: Optional[np.ndarray] = None,
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BILINEAR,
......@@ -93,7 +93,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
super().__init__(**kwargs)
size = size if size is not None else {"height": 256, "width": 256}
size = get_size_dict(size)
self.clusters = clusters
self.clusters = np.array(clusters) if clusters is not None else None
self.do_resize = do_resize
self.size = size
self.resample = resample
......@@ -154,7 +154,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
resample: PILImageResampling = None,
do_normalize: bool = None,
do_color_quantize: Optional[bool] = None,
clusters: Optional[Union[int, List[int]]] = None,
clusters: Optional[Union[List[List[int]], np.ndarray]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
data_format: Optional[Union[str, ChannelDimension]] = ChannelDimension.FIRST,
**kwargs,
......@@ -176,7 +176,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
Whether to normalize the image
do_color_quantize (`bool`, *optional*, defaults to `self.do_color_quantize`):
Whether to color quantize the image.
clusters (`np.ndarray`, *optional*, defaults to `self.clusters`):
clusters (`np.ndarray` or `List[List[int]]`, *optional*, defaults to `self.clusters`):
Clusters used to quantize the image of shape `(n_clusters, 3)`. Only has an effect if
`do_color_quantize` is set to `True`.
return_tensors (`str` or `TensorType`, *optional*):
......@@ -199,6 +199,7 @@ class ImageGPTImageProcessor(BaseImageProcessor):
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
do_color_quantize = do_color_quantize if do_color_quantize is not None else self.do_color_quantize
clusters = clusters if clusters is not None else self.clusters
clusters = np.array(clusters)
images = make_list_of_images(images)
......@@ -227,7 +228,6 @@ class ImageGPTImageProcessor(BaseImageProcessor):
images = [to_channel_dimension_format(image, ChannelDimension.LAST) for image in images]
# color quantize from (batch_size, height, width, 3) to (batch_size, height, width)
images = np.array(images)
clusters = np.array(clusters)
images = color_quantize(images, clusters).reshape(images.shape[:-1])
# flatten to (batch_size, height*width)
......
......@@ -983,9 +983,9 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
>>> model.to(device)
>>> # unconditional generation of 8 images
>>> batch_size = 8
>>> batch_size = 4
>>> context = torch.full((batch_size, 1), model.config.vocab_size - 1) # initialize with SOS token
>>> context = torch.tensor(context).to(device)
>>> context = context.to(device)
>>> output = model.generate(
... input_ids=context, max_length=model.config.n_positions + 1, temperature=1.0, do_sample=True, top_k=40
... )
......
......@@ -102,6 +102,7 @@ src/transformers/models/groupvit/modeling_groupvit.py
src/transformers/models/groupvit/modeling_tf_groupvit.py
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/imagegpt/configuration_imagegpt.py
src/transformers/models/imagegpt/modeling_imagegpt.py
src/transformers/models/layoutlm/configuration_layoutlm.py
src/transformers/models/layoutlm/modeling_layoutlm.py
src/transformers/models/layoutlm/modeling_tf_layoutlm.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment