Commit d8e83de7 authored by LysandreJik

GPT2 can be exported to TorchScript

parent e891bb43
@@ -328,7 +328,8 @@ class GPT2LMHead(nn.Module):
     def set_embeddings_weights(self, model_embeddings_weights, predict_special_tokens=True):
         self.predict_special_tokens = predict_special_tokens
-        self.decoder.weight = model_embeddings_weights  # Tied weights
+        # Export to TorchScript can't handle parameter sharing so we are cloning them.
+        self.decoder.weight = nn.Parameter(model_embeddings_weights.clone())  # Tied weights
 
     def forward(self, hidden_state):
         lm_logits = self.decoder(hidden_state)
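Note on the cloned tied weights above: the snippet below is a minimal, self-contained sketch of the same pattern, not code from this repository. The TiedDecoder class and every name in it are made up for illustration; it only shows that cloning the shared embedding matrix into a fresh nn.Parameter keeps the values while dropping the parameter sharing, after which torch.jit.trace handles the module without complaint.

import torch
import torch.nn as nn

class TiedDecoder(nn.Module):
    """Toy stand-in for GPT2LMHead: the decoder projects hidden states back
    onto the vocabulary with the embedding matrix. Cloning into a fresh
    nn.Parameter keeps the values but removes the weight sharing."""
    def __init__(self, embedding_weights):
        super().__init__()
        vocab_size, hidden_size = embedding_weights.shape
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        # Same move as the commit: clone instead of assigning the shared tensor.
        self.decoder.weight = nn.Parameter(embedding_weights.clone())

    def forward(self, hidden_state):
        return self.decoder(hidden_state)

embeddings = nn.Embedding(100, 16)
head = TiedDecoder(embeddings.weight)
traced_head = torch.jit.trace(head, torch.randn(2, 16))  # traces cleanly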
@@ -557,16 +558,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_shape = input_shape + (hidden_states.size(-1),)
 
-        presents = []
+        presents = ()
         all_attentions = []
-        all_hidden_states = []
+        all_hidden_states = ()
         for i, (block, layer_past) in enumerate(zip(self.h, past)):
             if self.output_hidden_states:
-                all_hidden_states.append(hidden_states.view(*output_shape))
+                all_hidden_states += (hidden_states.view(*output_shape),)
 
             outputs = block(hidden_states, layer_past, head_mask[i])
             hidden_states, present = outputs[:2]
-            presents.append(present)
+            presents += (present,)
 
             if self.output_attentions:
                 all_attentions.append(outputs[2])
@@ -576,16 +577,16 @@ class GPT2Model(GPT2PreTrainedModel):
         hidden_states = hidden_states.view(*output_shape)
         # Add last hidden state
         if self.output_hidden_states:
-            all_hidden_states.append(hidden_states)
+            all_hidden_states += (hidden_states,)
 
-        outputs = [hidden_states, presents]
+        outputs = (hidden_states, presents)
         if self.output_hidden_states:
-            outputs.append(all_hidden_states)
+            outputs += (all_hidden_states,)
         if self.output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
-            all_attentions = list(t.view(*attention_output_shape) for t in all_attentions)
-            outputs.append(all_attentions)
+            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
+            outputs += (all_attentions,)
         return outputs  # last hidden state, presents, (all hidden_states), (attentions)
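The switch from lists to tuples in GPT2Model.forward lines up with torch.jit.trace requiring (at least in the PyTorch releases of the time) that traced outputs be tensors or nested tuples of tensors, so lists built up in the layer loop get in the way. Below is a rough, hedged illustration of the accumulation pattern with a toy module; nothing in it comes from the repository.

import torch
import torch.nn as nn

class TupleLoop(nn.Module):
    """Mimics the GPT2Model loop: per-layer outputs are collected in a
    tuple ("presents") and returned alongside the final hidden state."""
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))

    def forward(self, hidden_states):
        presents = ()
        for layer in self.layers:
            hidden_states = layer(hidden_states)
            presents = presents + (hidden_states,)
        return hidden_states, presents  # a tensor plus a tuple of tensors

traced = torch.jit.trace(TupleLoop(), torch.randn(2, 8))
last_hidden, presents = traced(torch.randn(2, 8))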
@@ -658,7 +659,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
 
-        outputs = [lm_logits] + transformer_outputs[1:]
+        outputs = (lm_logits,) + transformer_outputs[1:]
         if lm_labels is not None:
             # Shift so that tokens < n predict n
             shift_logits = lm_logits[..., :-1, :].contiguous()
@@ -667,7 +668,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)
@@ -750,18 +751,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         lm_logits = self.lm_head(hidden_states)
         mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids)
 
-        outputs = [lm_logits, mc_logits] + transformer_outputs[1:]
+        outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
         if mc_labels is not None:
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(mc_logits.view(-1, mc_logits.size(-1)),
                             mc_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = [loss] + outputs
+            outputs = (loss,) + outputs
 
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
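With the tuple outputs and the cloned LM head in place, end-to-end export looks roughly like the sketch below. Hedged assumptions: the import path (pytorch_transformers) and the deliberately tiny GPT2Config are a plausible setup for this point in the project's history, not part of the commit itself.

import torch
from pytorch_transformers import GPT2Config, GPT2LMHeadModel  # import path is an assumption

# A tiny randomly-initialised model so the example runs quickly; a pretrained
# checkpoint via GPT2LMHeadModel.from_pretrained('gpt2') works the same way.
config = GPT2Config(n_embd=64, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config)
model.eval()

dummy_input_ids = torch.randint(0, config.vocab_size, (1, 8), dtype=torch.long)

# Tracing records the forward pass; the tuple returns and the cloned lm_head
# weights are what let this succeed after the commit.
traced_model = torch.jit.trace(model, dummy_input_ids)
traced_model.save("gpt2_traced.pt")

reloaded = torch.jit.load("gpt2_traced.pt")
lm_logits = reloaded(dummy_input_ids)[0]  # first element of the returned tuple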