"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "05ed569c79b59e69270a2000b95383a09d7f16fd"
Unverified Commit e5bc438c authored by Yih-Dar, committed by GitHub
Browse files

[Fix doc example] Fix 2 PyTorch Vilt docstring examples (#16076)



* fix 2 pytorch vilt docstring examples

* add vilt to doctest list file

* remove device
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent bcaf5660
...@@ -933,6 +933,7 @@ class ViltForMaskedLM(ViltPreTrainedModel): ...@@ -933,6 +933,7 @@ class ViltForMaskedLM(ViltPreTrainedModel):
>>> import requests >>> import requests
>>> from PIL import Image >>> from PIL import Image
>>> import re >>> import re
>>> import torch
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
...@@ -954,9 +955,9 @@ class ViltForMaskedLM(ViltPreTrainedModel): ...@@ -954,9 +955,9 @@ class ViltForMaskedLM(ViltPreTrainedModel):
>>> with torch.no_grad(): >>> with torch.no_grad():
... for i in range(tl): ... for i in range(tl):
... encoded = processor.tokenizer(inferred_token) ... encoded = processor.tokenizer(inferred_token)
... input_ids = torch.tensor(encoded.input_ids).to(device) ... input_ids = torch.tensor(encoded.input_ids)
... encoded = encoded["input_ids"][0][1:-1] ... encoded = encoded["input_ids"][0][1:-1]
... outputs = model(input_ids=input_ids, pixel_values=pixel_values) ... outputs = model(input_ids=input_ids, pixel_values=encoding.pixel_values)
... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size) ... mlm_logits = outputs.logits[0] # shape (seq_len, vocab_size)
... # only take into account text features (minus CLS and SEP token) ... # only take into account text features (minus CLS and SEP token)
... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :] ... mlm_logits = mlm_logits[1 : input_ids.shape[1] - 1, :]
...@@ -969,7 +970,8 @@ class ViltForMaskedLM(ViltPreTrainedModel): ...@@ -969,7 +970,8 @@ class ViltForMaskedLM(ViltPreTrainedModel):
>>> selected_token = "" >>> selected_token = ""
>>> encoded = processor.tokenizer(inferred_token) >>> encoded = processor.tokenizer(inferred_token)
>>> processor.decode(encoded.input_ids[0], skip_special_tokens=True) >>> output = processor.decode(encoded.input_ids[0], skip_special_tokens=True)
>>> print(output)
a bunch of cats laying on a couch. a bunch of cats laying on a couch.
```""" ```"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
...@@ -1215,12 +1217,10 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel): ...@@ -1215,12 +1217,10 @@ class ViltForImageAndTextRetrieval(ViltPreTrainedModel):
>>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco") >>> processor = ViltProcessor.from_pretrained("dandelin/vilt-b32-finetuned-coco")
>>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco") >>> model = ViltForImageAndTextRetrieval.from_pretrained("dandelin/vilt-b32-finetuned-coco")
>>> # prepare inputs
>>> encoding = processor(image, text, return_tensors="pt")
>>> # forward pass >>> # forward pass
>>> scores = dict() >>> scores = dict()
>>> for text in texts: >>> for text in texts:
... # prepare inputs
... encoding = processor(image, text, return_tensors="pt") ... encoding = processor(image, text, return_tensors="pt")
... outputs = model(**encoding) ... outputs = model(**encoding)
... scores[text] = outputs.logits[0, :].item() ... scores[text] = outputs.logits[0, :].item()
......
...@@ -18,6 +18,7 @@ src/transformers/models/swin/modeling_swin.py ...@@ -18,6 +18,7 @@ src/transformers/models/swin/modeling_swin.py
src/transformers/models/convnext/modeling_convnext.py src/transformers/models/convnext/modeling_convnext.py
src/transformers/models/poolformer/modeling_poolformer.py src/transformers/models/poolformer/modeling_poolformer.py
src/transformers/models/vit_mae/modeling_vit_mae.py src/transformers/models/vit_mae/modeling_vit_mae.py
src/transformers/models/vilt/modeling_vilt.py
src/transformers/models/van/modeling_van.py src/transformers/models/van/modeling_van.py
src/transformers/models/segformer/modeling_segformer.py src/transformers/models/segformer/modeling_segformer.py
src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment