"docs/source/vscode:/vscode.git/clone" did not exist on "7032e0203262ebb2ebf55da8d2e01f873973e835"
Unverified commit 65001cb1, authored by calpt and committed by GitHub
Browse files

Loosen output shape restrictions on GPT-style models (#25188)

* Loosen output shape restrictions on GPT-style models

* Use more self-explanatory variables

* Revert "Use more self-explanatory variables"

This reverts commit 5fd9ab39119558b7e750f61aa4a19014dccc5ed5.
parent d6bfba76
......@@ -603,7 +603,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
hidden_states = self.drop(hidden_states)
- output_shape = input_shape + (hidden_states.size(-1),)
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
......
......@@ -849,7 +849,7 @@ class GPT2Model(GPT2PreTrainedModel):
hidden_states = self.drop(hidden_states)
- output_shape = input_shape + (hidden_states.size(-1),)
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
......
......@@ -588,7 +588,7 @@ class GPTNeoModel(GPTNeoPreTrainedModel):
hidden_states = self.drop(hidden_states)
- output_shape = input_shape + (hidden_states.size(-1),)
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
......
......@@ -641,7 +641,7 @@ class GPTJModel(GPTJPreTrainedModel):
hidden_states = self.drop(hidden_states)
- output_shape = input_shape + (hidden_states.size(-1),)
+ output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment