Unverified Commit 120176ea authored by Patrick von Platen, committed by GitHub

[Longformer] Fix longformer documentation (#7016)

* fix longformer

* allow position ids to not be initialized
parent 5c4eb4b1
@@ -795,6 +795,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
+    authorized_missing_keys = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1019,11 +1020,13 @@ class LongformerModel(LongformerPreTrainedModel):
         >>> # Attention mask values -- 0: no attention, 1: local attention, 2: global attention
         >>> attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to local attention
-        >>> attention_mask[:, [1, 4, 21,]] = 2  # Set global attention based on the task. For example,
+        >>> global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=input_ids.device) # initialize to global attention to be deactivated for all tokens
+        >>> global_attention_mask[:, [1, 4, 21,]] = 1  # Set global attention to random tokens for the sake of this example
+        ...                     # Usually, set global attention based on the task. For example,
         ...                     # classification: the <s> token
         ...                     # QA: question tokens
         ...                     # LM: potentially on the beginning of sentences and paragraphs
-        >>> outputs = model(input_ids, attention_mask=attention_mask)
+        >>> outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask)
         >>> sequence_output = outputs.last_hidden_state
         >>> pooled_output = outputs.pooler_output
         """
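The `authorized_missing_keys = [r"position_ids"]` addition means `from_pretrained` no longer reports `position_ids` as a missing key when a checkpoint was saved without that buffer. A small check of that behavior, assuming the standard `output_loading_info=True` flag and the same illustrative checkpoint as above:

```python
# Sketch: keys matching an authorized_missing_keys pattern are filtered out of the
# reported missing keys, whether or not the checkpoint actually stores that buffer.
from transformers import LongformerModel

model, loading_info = LongformerModel.from_pretrained(
    "allenai/longformer-base-4096", output_loading_info=True
)

# Any genuinely missing weights would still appear here; position_ids does not.
assert not any("position_ids" in key for key in loading_info["missing_keys"])
```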