Commit be54b169 authored by LysandreJik

GPT can be exported to TorchScript

parent d8e83de7
@@ -355,7 +355,7 @@ class OpenAIGPTLMHead(nn.Module):
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
         embed_shape = model_embeddings_weights.shape
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights

     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
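The hunk above stops sharing one Parameter object between the embedding and the LM decoder: the weights are still copied at tying time, but cloned into a fresh nn.Parameter so that torch.jit.trace, which does not handle tied (shared) weights, can export the module. A minimal sketch of the difference, using a standalone nn.Embedding / nn.Linear pair with illustrative sizes rather than the actual OpenAIGPTLMHead:

import torch.nn as nn

embed = nn.Embedding(100, 16)            # illustrative dims, not GPT's real vocab/hidden sizes
decoder = nn.Linear(16, 100, bias=False)

# Old behaviour: tie by direct assignment, so both modules hold the very same
# Parameter object -- this sharing is what blocks TorchScript export.
decoder.weight = embed.weight
assert decoder.weight is embed.weight

# New behaviour: copy the values into a fresh Parameter; numerically identical
# at export time, but no longer shared, so the module can be traced.
decoder.weight = nn.Parameter(embed.weight.clone())
assert decoder.weight is not embed.weight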
@@ -579,26 +579,26 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)

-        all_attentions = []
-        all_hidden_states = []
+        all_attentions = ()
+        all_hidden_states = ()
         for i, block in enumerate(self.h):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states += (hidden_states.view(*output_shape),)

             outputs = block(hidden_states, head_mask[i])
             hidden_states = outputs[0]
             if self.output_attentions:
-                all_attentions.append(outputs[1])
+                all_attentions += (outputs[1],)

         # Add last layer
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states.view(*output_shape))
+            all_hidden_states += (hidden_states.view(*output_shape),)

-        outputs = [hidden_states.view(*output_shape)]
+        outputs = (hidden_states.view(*output_shape),)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs += (all_hidden_states,)
         if self.output_attentions:
-            outputs.append(all_attentions)
+            outputs += (all_attentions,)

         return outputs  # last hidden state, (all hidden states), (all attentions)
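Accumulating the hidden states and attentions as tuples instead of lists keeps the forward pass returning tensors and (nested) tuples of tensors, which is the output structure torch.jit.trace supports. A hedged, self-contained sketch of the pattern with a toy module, not the real OpenAIGPTModel:

import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Illustrative stand-in for a block stack that also returns all hidden states."""
    def forward(self, x):
        all_hidden_states = ()
        hidden_states = x
        for _ in range(2):
            hidden_states = torch.tanh(hidden_states)
            # Tuple concatenation instead of list.append, as in the hunk above.
            all_hidden_states += (hidden_states,)
        # Tensors and nested tuples of tensors are traceable output types.
        return (hidden_states, all_hidden_states)

traced = torch.jit.trace(ToyEncoder(), torch.randn(2, 4))
last_hidden, all_states = traced(torch.randn(2, 4))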
@@ -682,7 +682,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)

-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -691,7 +691,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (loss), lm_logits, (all hidden states), (all attentions)
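From the caller's side the return type changes from a list to a tuple, but indexing is unchanged. A short usage sketch, assuming the pytorch_transformers package layout and the 'openai-gpt' checkpoint (downloads weights when run); names and calling convention are as I understand the library, not taken from this diff:

import torch
from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer

tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')
model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

input_ids = torch.tensor([tokenizer.encode("hello world")])

# With lm_labels, the first element is the loss, then the LM logits.
loss, lm_logits = model(input_ids, lm_labels=input_ids)[:2]

# Without labels, the logits come first.
lm_logits = model(input_ids)[0]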
@@ -785,18 +785,18 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)

-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                             mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs

         return outputs  # (lm loss), (mc loss), lm logits, mc logits, (all hidden_states), (attentions)
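Taken together (no shared Parameter objects, tuple-only outputs), these changes are what the commit title refers to: the model can now go through torch.jit.trace. A minimal export sketch under the same assumptions as above (pytorch_transformers naming, 'openai-gpt' checkpoint, an arbitrary dummy input shape):

import torch
from pytorch_transformers import OpenAIGPTLMHeadModel

model = OpenAIGPTLMHeadModel.from_pretrained('openai-gpt')
model.eval()

# Dummy input ids; trace records the graph for this batch/sequence shape.
dummy_input = torch.zeros(1, 16, dtype=torch.long)

traced = torch.jit.trace(model, dummy_input)
torch.jit.save(traced, "openai_gpt_traced.pt")

# The traced module keeps the tuple output structure of the eager model.
loaded = torch.jit.load("openai_gpt_traced.pt")
lm_logits = loaded(dummy_input)[0]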