Unverified commit b86bf441, authored by Frank Wang, committed by GitHub
Browse files

[Bugfix] Fix Random Dataset Prefix Length Inaccuracy (#33907)


Signed-off-by: frankwang28 <frank.wbb@hotmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
parent de13dd78
...@@ -380,7 +380,7 @@ def gen_prompt_decode_to_target_len( ...@@ -380,7 +380,7 @@ def gen_prompt_decode_to_target_len(
max_retry: int = 10, max_retry: int = 10,
add_special_tokens: bool = False, add_special_tokens: bool = False,
rng: np.random.Generator | None = None, rng: np.random.Generator | None = None,
) -> tuple[str, list[int]]: ) -> tuple[str, list[int], int]:
""" """
Ensure decoded-then-encoded prompt length matches the target token length. Ensure decoded-then-encoded prompt length matches the target token length.
...@@ -392,7 +392,9 @@ def gen_prompt_decode_to_target_len( ...@@ -392,7 +392,9 @@ def gen_prompt_decode_to_target_len(
[6880, 6881] -> ['Ġcalls', 'here'] -> [6880, 6881] -> ['Ġcalls', 'here'] ->
[1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
Returns a tuple of the final prompt string and the adjusted token sequence. Returns a tuple of the final prompt string, the adjusted token sequence,
and the token mismatch (final_len - target_token_len) if the retry budget
is exhausted.
""" """
remain_num_try = max_retry remain_num_try = max_retry
token_mismatch = 0 token_mismatch = 0
...@@ -499,7 +501,7 @@ class RandomDataset(BenchmarkDataset): ...@@ -499,7 +501,7 @@ class RandomDataset(BenchmarkDataset):
allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens))) allowed_tokens = np.array(list(set(all_tokens) - set(prohibited_tokens)))
# Generate prefix once # Generate prefix once
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len) prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
requests = [] requests = []
token_mismatch_total = 0 token_mismatch_total = 0
...@@ -554,19 +556,36 @@ class RandomDataset(BenchmarkDataset): ...@@ -554,19 +556,36 @@ class RandomDataset(BenchmarkDataset):
def get_prefix( def get_prefix(
self, self,
tokenizer: TokenizerLike,
allowed_tokens: np.ndarray, allowed_tokens: np.ndarray,
prefix_len: int, prefix_len: int,
) -> list[int]: ) -> list[int]:
""" """
Get the prefix for the dataset. Get the prefix for the dataset.
""" """
return ( if prefix_len <= 0:
allowed_tokens[ return []
prefix_tokens = allowed_tokens[
self._rng.integers(0, len(allowed_tokens), size=prefix_len) self._rng.integers(0, len(allowed_tokens), size=prefix_len)
].tolist() ].tolist()
if prefix_len > 0 _, adjusted_tokens, token_mismatch = gen_prompt_decode_to_target_len(
else [] tokenizer=tokenizer,
token_sequence=prefix_tokens,
target_token_len=prefix_len,
add_special_tokens=False,
rng=self._rng,
)
if token_mismatch != 0:
sign = "more" if token_mismatch > 0 else "fewer"
logger.warning(
"Prefix tokenization produced %d %s tokens than expected "
"after decoding and re-encoding. This is expected due to "
"the imperfect nature of the sampling procedure",
abs(token_mismatch),
sign,
) )
return adjusted_tokens
def get_sampling_params( def get_sampling_params(
self, self,
...@@ -1128,7 +1147,7 @@ class RandomMultiModalDataset(RandomDataset): ...@@ -1128,7 +1147,7 @@ class RandomMultiModalDataset(RandomDataset):
"Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size "Sampling from %d out of %d (vocab size)", len(allowed_tokens), vocab_size
) )
# Generate prefix once # Generate prefix once
prefix_token_ids = self.get_prefix(allowed_tokens, prefix_len) prefix_token_ids = self.get_prefix(tokenizer, allowed_tokens, prefix_len)
# Add synthetic multimodal items to each request # Add synthetic multimodal items to each request
mm_requests = [] mm_requests = []
token_mismatch_total = 0 token_mismatch_total = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment