Unverified Commit b25067d8 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

[Fix doc example] TFRagModel (#15187)



* fix doc example - NameError: name 'PATH' is not defined

* fix name 'TFRagModel' is not defined

* correct TFRagRagSequenceForGeneration

* fix name 'tf' is not defined

* fix style
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent dea563c9
...@@ -293,7 +293,9 @@ class TFRagPreTrainedModel(TFPreTrainedModel): ...@@ -293,7 +293,9 @@ class TFRagPreTrainedModel(TFPreTrainedModel):
>>> model.save_pretrained("./rag") >>> model.save_pretrained("./rag")
>>> # load retriever >>> # load retriever
>>> retriever = RagRetriever.from_pretrained(PATH, index_name="exact", use_dummy_dataset=True) >>> retriever = RagRetriever.from_pretrained(
... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
... )
>>> # load fine-tuned model with retriever >>> # load fine-tuned model with retriever
>>> model = TFRagModel.from_pretrained("./rag", retriever=retriever) >>> model = TFRagModel.from_pretrained("./rag", retriever=retriever)
```""" ```"""
...@@ -559,7 +561,7 @@ class TFRagModel(TFRagPreTrainedModel): ...@@ -559,7 +561,7 @@ class TFRagModel(TFRagPreTrainedModel):
Example: Example:
```python ```python
>>> from transformers import RagTokenizer, RagRetriever, RagModel >>> from transformers import RagTokenizer, RagRetriever, TFRagModel
>>> import torch >>> import torch
>>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base") >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-base")
...@@ -939,6 +941,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss ...@@ -939,6 +941,7 @@ class TFRagTokenForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingLoss
Example: Example:
```python ```python
>>> import tensorflow as tf
>>> from transformers import RagTokenizer, RagRetriever, TFRagTokenForGeneration >>> from transformers import RagTokenizer, RagRetriever, TFRagTokenForGeneration
>>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq") >>> tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
...@@ -1554,7 +1557,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL ...@@ -1554,7 +1557,7 @@ class TFRagSequenceForGeneration(TFRagPreTrainedModel, TFCausalLanguageModelingL
... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True ... "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
... ) ... )
>>> # initialize with RagRetriever to do everything in one forward call >>> # initialize with RagRetriever to do everything in one forward call
>>> model = TFRagRagSequenceForGeneration.from_pretrained( >>> model = TFRagSequenceForGeneration.from_pretrained(
... "facebook/rag-sequence-nq", retriever=retriever, from_pt=True ... "facebook/rag-sequence-nq", retriever=retriever, from_pt=True
... ) ... )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment