Unverified Commit 04c446f7 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Make `ModelOutput` pickle-able (#8989)

parent 0d9e6ca9
......@@ -41,7 +41,7 @@ class BaseModelOutput(ModelOutput):
heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -71,7 +71,7 @@ class BaseModelOutputWithPooling(ModelOutput):
heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -107,7 +107,7 @@ class BaseModelOutputWithPast(ModelOutput):
heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -140,7 +140,7 @@ class BaseModelOutputWithCrossAttentions(ModelOutput):
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -177,7 +177,7 @@ class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -220,7 +220,7 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -277,7 +277,7 @@ class Seq2SeqModelOutput(ModelOutput):
self-attention heads.
"""
last_hidden_state: torch.FloatTensor
last_hidden_state: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -310,7 +310,7 @@ class CausalLMOutput(ModelOutput):
heads.
"""
loss: Optional[torch.FloatTensor]
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -381,7 +381,7 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):
cross-attention heads.
"""
loss: Optional[torch.FloatTensor]
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment