Unverified Commit 04c446f7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make `ModelOutput` pickle-able (#8989)

parent 0d9e6ca9
...@@ -41,7 +41,7 @@ class BaseModelOutput(ModelOutput): ...@@ -41,7 +41,7 @@ class BaseModelOutput(ModelOutput):
heads. heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -71,7 +71,7 @@ class BaseModelOutputWithPooling(ModelOutput): ...@@ -71,7 +71,7 @@ class BaseModelOutputWithPooling(ModelOutput):
heads. heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -107,7 +107,7 @@ class BaseModelOutputWithPast(ModelOutput): ...@@ -107,7 +107,7 @@ class BaseModelOutputWithPast(ModelOutput):
heads. heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -140,7 +140,7 @@ class BaseModelOutputWithCrossAttentions(ModelOutput): ...@@ -140,7 +140,7 @@ class BaseModelOutputWithCrossAttentions(ModelOutput):
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -177,7 +177,7 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput): ...@@ -177,7 +177,7 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -220,7 +220,7 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): ...@@ -220,7 +220,7 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
weighted average in the cross-attention heads. weighted average in the cross-attention heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -277,7 +277,7 @@ class Seq2SeqModelOutput(ModelOutput): ...@@ -277,7 +277,7 @@ class Seq2SeqModelOutput(ModelOutput):
self-attention heads. self-attention heads.
""" """
last_hidden_state: torch.FloatTensor last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -310,7 +310,7 @@ class CausalLMOutput(ModelOutput): ...@@ -310,7 +310,7 @@ class CausalLMOutput(ModelOutput):
heads. heads.
""" """
loss: Optional[torch.FloatTensor] loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -381,7 +381,7 @@ class CausalLMOutputWithCrossAttentions(ModelOutput): ...@@ -381,7 +381,7 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
cross-attention heads. cross-attention heads.
""" """
loss: Optional[torch.FloatTensor] loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment