Unverified Commit 5a70a77b authored by Ahmed Elnaggar, committed by GitHub

Add Support to Gradient Checkpointing for LongT5 (#18977)

FlaxLongT5PreTrainedModel is missing the `enable_gradient_checkpointing` method, so attempting to enable gradient checkpointing for LongT5 raises an error.
This pull request adds the method.
parent 4157e3cd
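
For context, a minimal usage sketch of what this patch enables (not part of the commit; the checkpoint name "google/long-t5-local-base" is an illustrative choice, any published LongT5 checkpoint works the same way):

# Minimal usage sketch, assuming the patch below is applied.
# "google/long-t5-local-base" is an illustrative checkpoint name,
# not taken from this commit.
from transformers import FlaxLongT5ForConditionalGeneration

model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")

# Before this commit, this call failed because FlaxLongT5PreTrainedModel
# did not override the method; with the patch it rebuilds the underlying
# module with gradient_checkpointing=True, trading recomputation for
# lower peak memory during training.
model.enable_gradient_checkpointing()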
@@ -1686,6 +1686,13 @@ class FlaxLongT5PreTrainedModel(FlaxPreTrainedModel):
         module = self.module_class(config=config, dtype=dtype, **kwargs)
         super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
 
+    def enable_gradient_checkpointing(self):
+        self._module = self.module_class(
+            config=self.config,
+            dtype=self.dtype,
+            gradient_checkpointing=True,
+        )
+
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
         # init input tensors
         input_ids = jnp.zeros(input_shape, dtype="i4")
...
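
Under the hood, passing gradient_checkpointing=True to the module causes the transformer blocks to be wrapped with flax.linen.remat, mirroring how the Flax T5 implementation handles the flag. A self-contained sketch of that mechanism (Block and Stack are hypothetical stand-ins, not code from this repository):

# Illustrative sketch of the mechanism, not code from this repository:
# flax.linen.remat (a.k.a. nn.checkpoint) recomputes a block's activations
# during the backward pass instead of storing them, which is what the
# gradient_checkpointing flag toggles inside the LongT5 module.
import flax.linen as nn

class Block(nn.Module):  # hypothetical stand-in for one LongT5 layer
    @nn.compact
    def __call__(self, x):
        return nn.Dense(128)(nn.relu(nn.Dense(512)(x)))

class Stack(nn.Module):  # hypothetical stand-in for the block stack
    gradient_checkpointing: bool = False

    @nn.compact
    def __call__(self, x):
        # With checkpointing on, wrap the block class in nn.remat so its
        # intermediates are recomputed on the backward pass rather than
        # kept in memory.
        block_cls = nn.remat(Block) if self.gradient_checkpointing else Block
        for _ in range(4):
            x = block_cls()(x)
        return x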