Commit 69511cdc (unverified), authored by Suraj Patil, committed by GitHub
Browse files

unfreeze initial cache in gpt models (#14535)

parent 2318bf77
......@@ -444,7 +444,7 @@ class FlaxGPT2PreTrainedModel(FlaxPreTrainedModel):
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
- return init_variables["cache"]
+ return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
def __call__(
......
......@@ -388,7 +388,7 @@ class FlaxGPTNeoPreTrainedModel(FlaxPreTrainedModel):
init_variables = self.module.init(
jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
)
- return init_variables["cache"]
+ return unfreeze(init_variables["cache"])
@add_start_docstrings_to_model_forward(GPT_NEO_INPUTS_DOCSTRING)
def __call__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment