Unverified Commit c53a6eae authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[`RWKV`] Add note in doc on `RwkvStoppingCriteria` (#25055)

* Add note in doc on `RwkvStoppingCriteria`

* give some breathing space to the code
parent d2295708
...@@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state ...@@ -51,6 +51,24 @@ output_two = outputs.last_hidden_state
torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5) torch.allclose(torch.cat([output_one, output_two], dim=1), output_whole, atol=1e-5)
``` ```
If you want to make sure the model stops generating when `'\n\n'` is detected, we recommend using the following stopping criteria:
```python
from transformers import StoppingCriteria
class RwkvStoppingCriteria(StoppingCriteria):
def __init__(self, eos_sequence = [187,187], eos_token_id = 537):
self.eos_sequence = eos_sequence
self.eos_token_id = eos_token_id
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
last_2_ids = input_ids[:,-2:].tolist()
return self.eos_sequence in last_2_ids
output = model.generate(inputs["input_ids"], max_new_tokens=64, stopping_criteria = [RwkvStoppingCriteria()])
```
## RwkvConfig ## RwkvConfig
[[autodoc]] RwkvConfig [[autodoc]] RwkvConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment