Unverified Commit 21f7e81b authored by Yih-Dar, committed by GitHub

Make `RwkvModel` accept `attention_mask` but discard it internally (#23442)



* fix

---------
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
parent cf432008
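
For context, a minimal sketch of what this change enables (the `RWKV/rwkv-4-169m-pile` checkpoint name is only an assumption for illustration): many callers, such as pipelines and generic training loops, forward the tokenizer's `attention_mask` along with `input_ids`, and before this patch `RwkvModel.forward()` rejected the unexpected keyword.

```python
import torch
from transformers import AutoTokenizer, RwkvModel

checkpoint = "RWKV/rwkv-4-169m-pile"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = RwkvModel.from_pretrained(checkpoint)

enc = tokenizer("Hello, RWKV!", return_tensors="pt")
attention_mask = torch.ones_like(enc["input_ids"])  # 1 = not masked, 0 = masked

with torch.no_grad():
    # With this patch the extra keyword is accepted and silently discarded.
    outputs = model(input_ids=enc["input_ids"], attention_mask=attention_mask)

print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)
```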
@@ -565,6 +565,15 @@ RWKV_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
attention_mask (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
This is currently not used by `RwkvModel`, but will be supported in the future.
[What are attention masks?](../glossary#attention-mask)
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
@@ -617,6 +626,7 @@ class RwkvModel(RwkvPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
state: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
@@ -750,7 +760,7 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.LongTensor] = None, # noqa
inputs_embeds: Optional[torch.FloatTensor] = None,
state: Optional[List[torch.FloatTensor]] = None,
labels: Optional[torch.LongTensor] = None,
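
The same `# noqa`-annotated parameter is kept on `RwkvForCausalLM.forward` above, so code that passes the full tokenizer output (including `attention_mask`) keeps working unchanged. A hedged sketch, using the same assumed checkpoint as above:

```python
import torch
from transformers import AutoTokenizer, RwkvForCausalLM

checkpoint = "RWKV/rwkv-4-169m-pile"  # assumed checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = RwkvForCausalLM.from_pretrained(checkpoint)

enc = tokenizer("The quick brown fox", return_tensors="pt")

with torch.no_grad():
    # **enc includes attention_mask; it is accepted here but ignored internally.
    out = model(**enc, labels=enc["input_ids"])

print(out.loss)
```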