"vscode:/vscode.git/clone" did not exist on "eaeea0719924211375e99fa4152d36d815f3b3bc"
Unverified commit 51ede33c, authored by Baber Abbasi and committed by GitHub

truncate thinking tags in generations (#3145)

* feat: add postprocessing for generated text to strip stop sequences and thinking tokens

* nit

* fix: trim leading whitespace after stripping thinking tokens from generation

* feat: add think_end_token to model_args

* nit

* nit

* nit

* add to readme

* nit
parent 3102a8e4
@@ -5,7 +5,7 @@
 ---
 ## Latest News 📣
+- [2025/07] Added `think_end_token` arg to `hf` (token/str), `vllm` and `sglang` (str) for stripping CoT reasoning traces from models that support it.
 - [2025/03] Added support for steering HF models!
 - [2025/02] Added [SGLang](https://docs.sglang.ai/) support!
 - [2024/09] We are prototyping allowing users of LM Evaluation Harness to create and evaluate on text+image multimodal input, text output tasks, and have just added the `hf-multimodal` and `vllm-vlm` model types and `mmmu` task as a prototype feature. We welcome users to try out this in-progress feature and stress-test it for themselves, and suggest they check out [`lmms-eval`](https://github.com/EvolvingLMMs-Lab/lmms-eval), a wonderful project originally forking off of the lm-evaluation-harness, for a broader range of multimodal tasks, models, and features.
...
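For context, a minimal usage sketch of the new arg (the checkpoint name below is a placeholder, not part of this PR); the same value can also be supplied as a `model_args` entry on the CLI:

```python
# Sketch only: passing the new `think_end_token` arg to the HF backend.
# The checkpoint name is a placeholder; any model that emits a closing
# thinking marker such as "</think>" would apply here.
from lm_eval.models.huggingface import HFLM

lm = HFLM(
    pretrained="org/reasoning-model",  # placeholder checkpoint
    think_end_token="</think>",        # string marker; `hf` also accepts an int token id
)
```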
@@ -34,6 +34,7 @@ from lm_eval.models.utils import (
     get_dtype,
     handle_stop_sequences,
     pad_and_concat,
+    postprocess_generated_text,
     stop_sequences_criteria,
 )
@@ -95,6 +96,9 @@ class HFLM(TemplateLM):
         autogptq: Optional[Union[bool, str]] = False,
         gptqmodel: Optional[bool] = False,
         gguf_file: Optional[str] = None,
+        # end token for thinking, either the string or int token id.
+        # splits to get response after this token (if provided).
+        think_end_token: Union[str, int, None] = None,
         **kwargs,
     ) -> None:
         super().__init__()
@@ -224,6 +228,11 @@ class HFLM(TemplateLM):
         self.model.eval()
         self.model.tie_weights()
+        self.think_end_token = (
+            int(think_end_token)
+            if (isinstance(think_end_token, str) and think_end_token.isdigit())
+            else think_end_token
+        )
         self.truncation = truncation
         self.logits_cache = logits_cache
         self.vocab_size = self.tokenizer.vocab_size
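A standalone sketch of the coercion above (the helper name is ours, not the library's): a digit-only string is treated as a token id, while a marker such as `</think>`, an int, or None passes through unchanged.

```python
# Sketch only: mirrors the conditional expression assigned to self.think_end_token.
def normalize_think_end_token(value):
    if isinstance(value, str) and value.isdigit():
        return int(value)  # digit-only string -> token id
    return value           # "</think>", an int, or None pass through unchanged

assert normalize_think_end_token("151668") == 151668
assert normalize_think_end_token("</think>") == "</think>"
assert normalize_think_end_token(None) is None
```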
@@ -1427,15 +1436,30 @@
                 if self.backend == "causal":
                     cont_toks = cont_toks[context_enc.shape[1] :]
-                s = self.tok_decode(cont_toks)
-                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-                for term in until:
-                    if len(term) > 0:
-                        # ignore '' separator,
-                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
-                        s = s.split(term)[0]
+                # Handle integer think_end_token: find last occurrence and strip tokens after it
+                if isinstance(self.think_end_token, int):
+                    think_token_indices = [
+                        i
+                        for i, token in enumerate(cont_toks)
+                        if token == self.think_end_token
+                    ]
+                    if think_token_indices:
+                        cont_toks = cont_toks[think_token_indices[-1] + 1 :]
+                s = self.tok_decode(cont_toks)
+                # Strip leading whitespace if we removed thinking tokens
+                if isinstance(self.think_end_token, int):
+                    s = s.lstrip()
+                # Apply post-processing: remove stop sequences and string-based thinking tokens
+                s = postprocess_generated_text(
+                    generation=s,
+                    stop=until,
+                    think_end_token=self.think_end_token
+                    if isinstance(self.think_end_token, str)
+                    else None,
+                )
                 res.append(s)
                 self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
...
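The integer path above amounts to keeping only the token ids after the last occurrence of the think-end id before decoding. A self-contained sketch with made-up ids (names are ours, not the PR's):

```python
# Illustrative only: keep the tokens after the *last* think-end token id,
# matching the list-comprehension logic in HFLM above.
def strip_think_tokens(cont_toks: list[int], think_end_id: int) -> list[int]:
    indices = [i for i, tok in enumerate(cont_toks) if tok == think_end_id]
    if indices:
        return cont_toks[indices[-1] + 1 :]
    return cont_toks

# With a made-up id 99 marking the end of the thinking block:
assert strip_think_tokens([5, 6, 99, 7, 8], 99) == [7, 8]
assert strip_think_tokens([5, 6, 7], 99) == [5, 6, 7]  # no marker: left untouched
```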
@@ -11,6 +11,7 @@ from lm_eval.api.registry import register_model
 from lm_eval.models.utils import (
     Collator,
     handle_stop_sequences,
+    postprocess_generated_text,
 )
 from lm_eval.utils import (
     get_rolling_token_windows,
@@ -59,6 +60,8 @@ class SGLangLM(TemplateLM):
         dp_size: int = 1,
         tp_size: int = 1,
         prefix_token_id: Optional[int] = None,
+        # End marker for thinking tags - splits to get response after this token (if provided).
+        think_end_token: Optional[str] = None,
         **kwargs,
     ):
         super().__init__()
@@ -74,6 +77,7 @@ class SGLangLM(TemplateLM):
             "Either context_length or max_model_len may be provided, but not both"
         )
         # Initialize your sglang model here
+        self.think_end_token = think_end_token
         self._max_length = (
             max_model_len if max_model_len is not None else context_length
         )
@@ -263,6 +267,9 @@
             # cache generations
             for output, context in zip(cont, context):
                 generated_text = output.get("text", "")
+                generated_text = postprocess_generated_text(
+                    generated_text, until, self.think_end_token
+                )
                 res.append(generated_text)
                 self.cache_hook.add_partial(
                     "generate_until", (context, gen_kwargs), generated_text
...
@@ -852,3 +852,32 @@ def truncate_tokens(
         right_length = max_length - left_length
         return tokens[:left_length] + tokens[-right_length:]
     return None
+
+
+def postprocess_generated_text(
+    generation: str, stop: Union[list[str], str, None], think_end_token: Optional[str]
+) -> str:
+    """
+    Post-processes the generated text by stripping stop sequences and optional thinking markers.
+
+    Args:
+        generation (str): The generated text to be processed.
+        stop (Optional[list[str]]): Stop sequence(s) to remove. Text is truncated
+            at the first occurrence of any stop sequence.
+        think_end_token (Optional[str]): Token marking end of thinking section. If provided,
+            returns only the text after this token (discarding thinking content).
+
+    Returns:
+        str: The processed generation - text before stop sequences and after thinking sections.
+    """
+    if stop:
+        stop = [stop] if isinstance(stop, str) else stop
+        for term in stop:
+            if len(term) > 0:
+                # ignore '' separator,
+                # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
+                generation = generation.split(term)[0]
+    if think_end_token:
+        generation = generation.split(think_end_token)[-1].lstrip()
+    return generation
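A usage sketch of the helper above (the strings are illustrative): stop-sequence truncation is applied first, then everything up to and including the last `think_end_token` occurrence is dropped and leading whitespace is trimmed.

```python
from lm_eval.models.utils import postprocess_generated_text

raw = "<think>scratch work...</think>\nFinal answer: 42\n\nQ: next question"
out = postprocess_generated_text(
    generation=raw,
    stop=["\n\nQ:"],             # truncate at the first stop sequence
    think_end_token="</think>",  # then keep only the text after the marker
)
assert out == "Final answer: 42"
```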
@@ -22,6 +22,7 @@ from lm_eval.models.utils import (
     Collator,
     configure_pad_token,
     handle_stop_sequences,
+    postprocess_generated_text,
     undistribute,
 )
 from lm_eval.utils import (
@@ -133,7 +134,10 @@ class VLLM(TemplateLM):
         device: str = "cuda",
         data_parallel_size: int = 1,
         lora_local_path: str = None,
-        enable_thinking: bool = False,
+        # VLLM: enable thinking tags in the prompt.
+        enable_thinking: bool = True,
+        # End marker for thinking tags - splits to get response after this token (if provided).
+        think_end_token: Optional[str] = None,
         max_lora_rank: int = 16,
         **kwargs,
     ):
@@ -148,6 +152,7 @@ class VLLM(TemplateLM):
         assert max_length is None or max_model_len is None, (
             "Either max_length or max_model_len may be provided, but not both"
         )
+        self.think_end_token = think_end_token
         self.V1 = os.environ.get("VLLM_USE_V1", "1") != "0"
         self._max_length = max_model_len if max_model_len is not None else max_length
         self.tensor_parallel_size = int(tensor_parallel_size)
@@ -630,11 +635,11 @@
         # cache generations
         for output, context in zip(cont, context):
-            generated_text = output.outputs[0].text
+            generated_text: str = output.outputs[0].text
             # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
-            for term in until:
-                if len(term) > 0:
-                    generated_text = generated_text.split(term)[0]
+            generated_text = postprocess_generated_text(
+                generated_text, until, self.think_end_token
+            )
             res.append(generated_text)
             self.cache_hook.add_partial(
                 "generate_until", (context, gen_kwargs), generated_text
...