Unverified Commit 50db7414 authored by Andy Ehrenberg's avatar Andy Ehrenberg Committed by GitHub
Browse files

check for None forced tokens (#21793)

parent 50644cf6
...@@ -328,7 +328,8 @@ class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor): ...@@ -328,7 +328,8 @@ class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
# Indexes without forced tokens will have a negative value. # Indexes without forced tokens will have a negative value.
force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1 force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
for index, token in force_token_map.items(): for index, token in force_token_map.items():
force_token_array = force_token_array.at[index].set(token) if token is not None:
force_token_array = force_token_array.at[index].set(token)
self.force_token_array = jnp.int32(force_token_array) self.force_token_array = jnp.int32(force_token_array)
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray: def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment