Commit c22545aa authored by thomwolf

fix xlm torchscript

parent 3b23a846
@@ -536,7 +536,7 @@ class XLMModel(XLMPreTrainedModel):
 
         # positions
         if positions is None:
-            positions = input_ids.new(slen).long()
+            positions = input_ids.new((slen,)).long()
             positions = torch.arange(slen, out=positions).unsqueeze(0)
         else:
             assert positions.size() == (bs, slen)  # (slen, bs)
@@ -585,17 +585,17 @@ class XLMModel(XLMPreTrainedModel):
         tensor *= mask.unsqueeze(-1).to(tensor.dtype)
 
         # transformer layers
-        hidden_states = []
-        attentions = []
+        hidden_states = ()
+        attentions = ()
         for i in range(self.n_layers):
             if self.output_hidden_states:
-                hidden_states.append(tensor)
+                hidden_states = hidden_states + (tensor,)
 
             # self attention
             attn_outputs = self.attentions[i](tensor, attn_mask, cache=cache, head_mask=head_mask[i])
             attn = attn_outputs[0]
             if self.output_attentions:
-                attentions.append(attn_outputs[1])
+                attentions = attentions + (attn_outputs[1],)
             attn = F.dropout(attn, p=self.dropout, training=self.training)
             tensor = tensor + attn
             tensor = self.layer_norm1[i](tensor)
@@ -614,7 +614,7 @@ class XLMModel(XLMPreTrainedModel):
 
         # Add last hidden state
         if self.output_hidden_states:
-            hidden_states.append(tensor)
+            hidden_states = hidden_states + (tensor,)
 
         # update cache length
         if cache is not None:
@@ -623,11 +623,11 @@ class XLMModel(XLMPreTrainedModel):
         # move back sequence length to dimension 0
         # tensor = tensor.transpose(0, 1)
 
-        outputs = [tensor]
+        outputs = (tensor,)
         if self.output_hidden_states:
-            outputs.append(hidden_states)
+            outputs = outputs + (hidden_states,)
         if self.output_attentions:
-            outputs.append(attentions)
+            outputs = outputs + (attentions,)
         return outputs  # outputs, (hidden_states), (attentions)
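For context: torch.jit.trace only supports tensors and (nested) tuples of tensors as module outputs, so accumulating hidden_states/attentions in Python lists and returning [tensor] made XLMModel untraceable; this commit switches every output container to a tuple. Below is a minimal, hypothetical usage sketch of tracing the model after the fix; it is not part of the commit, and the import path, config defaults, and dummy shapes are assumptions for illustration only.

    # Hypothetical sketch (not from this commit): tracing XLMModel after the fix.
    # Assumes the library's import path and default XLMConfig values.
    import torch
    from pytorch_transformers import XLMConfig, XLMModel  # assumed import path

    config = XLMConfig()              # defaults; forward() now returns tuples
    model = XLMModel(config)
    model.eval()

    # Dummy batch of token ids; shape and vocab range are arbitrary for illustration.
    input_ids = torch.randint(0, 100, (1, 12), dtype=torch.long)

    # trace() accepts tuple outputs; with list outputs this call would fail.
    traced = torch.jit.trace(model, (input_ids,))
    outputs = traced(input_ids)       # e.g. (last_hidden_state,)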