Unverified Commit be79cd7d authored by Matt's avatar Matt Committed by GitHub
Browse files

Protect `TFGenerationMixin.seed_generator` so it's not created at import (#18044)

parent 360719a6
...@@ -346,7 +346,14 @@ class TFGenerationMixin: ...@@ -346,7 +346,14 @@ class TFGenerationMixin:
A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`]. A class containing all of the functions supporting generation, to be used as a mixin in [`TFPreTrainedModel`].
""" """
seed_generator = tf.random.Generator.from_non_deterministic_state() _seed_generator = None
@property
def seed_generator(self):
if self._seed_generator is None:
self._seed_generator = tf.random.Generator.from_non_deterministic_state()
return self._seed_generator
supports_xla_generation = True supports_xla_generation = True
def prepare_inputs_for_generation(self, inputs, **kwargs): def prepare_inputs_for_generation(self, inputs, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment