Unverified commit 9432ed8c, authored by Harry Mellor and committed by GitHub
Browse files

Explicitly set `return_dict` for `apply_chat_template` (#33372)


Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
parent 726d8972
...@@ -38,8 +38,8 @@ def get_prompt_embeds( ...@@ -38,8 +38,8 @@ def get_prompt_embeds(
embedding_layer: torch.nn.Module, embedding_layer: torch.nn.Module,
): ):
token_ids = tokenizer.apply_chat_template( token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt" chat, add_generation_prompt=True, return_tensors="pt", return_dict=True
) ).input_ids
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
return prompt_embeds return prompt_embeds
......
...@@ -49,8 +49,8 @@ def main(): ...@@ -49,8 +49,8 @@ def main():
# Refer to the HuggingFace repo for the correct format to use # Refer to the HuggingFace repo for the correct format to use
chat = [{"role": "user", "content": "Please tell me about the capital of France."}] chat = [{"role": "user", "content": "Please tell me about the capital of France."}]
token_ids = tokenizer.apply_chat_template( token_ids = tokenizer.apply_chat_template(
chat, add_generation_prompt=True, return_tensors="pt" chat, add_generation_prompt=True, return_tensors="pt", return_dict=True
) ).input_ids
embedding_layer = transformers_model.get_input_embeddings() embedding_layer = transformers_model.get_input_embeddings()
prompt_embeds = embedding_layer(token_ids).squeeze(0) prompt_embeds = embedding_layer(token_ids).squeeze(0)
......
...@@ -27,7 +27,8 @@ def main(client): ...@@ -27,7 +27,8 @@ def main(client):
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False, enable_thinking=False,
) return_dict=True,
).input_ids
payload = { payload = {
"model": MODEL_NAME, "model": MODEL_NAME,
"token_ids": token_ids, "token_ids": token_ids,
......
...@@ -92,7 +92,8 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages): ...@@ -92,7 +92,8 @@ async def test_same_response_as_chat_completions(client, tokenizer, messages):
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False, # default with Qwen3 enable_thinking=False, # default with Qwen3
) return_dict=True, # default with Transformers v5
).input_ids
for ignore_eos in [True, False]: for ignore_eos in [True, False]:
payload = { payload = {
...@@ -155,7 +156,8 @@ async def test_stop_string_workflow(client, tokenizer, messages): ...@@ -155,7 +156,8 @@ async def test_stop_string_workflow(client, tokenizer, messages):
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False, # default with Qwen3 enable_thinking=False, # default with Qwen3
) return_dict=True, # default with Transformers v5
).input_ids
payload = { payload = {
"model": MODEL_NAME, "model": MODEL_NAME,
"token_ids": token_ids, "token_ids": token_ids,
...@@ -251,7 +253,8 @@ async def test_generate_with_lora_adapter(client, tokenizer, messages): ...@@ -251,7 +253,8 @@ async def test_generate_with_lora_adapter(client, tokenizer, messages):
messages, messages,
add_generation_prompt=True, add_generation_prompt=True,
enable_thinking=False, # default with Qwen3 enable_thinking=False, # default with Qwen3
) return_dict=True, # default with Transformers v5
).input_ids
payload = { payload = {
"model": "Alice", "model": "Alice",
"token_ids": token_ids, "token_ids": token_ids,
......
...@@ -759,6 +759,7 @@ class IsaacProcessor: ...@@ -759,6 +759,7 @@ class IsaacProcessor:
# Regular text message # Regular text message
processed_messages.append(message) processed_messages.append(message)
kwargs["return_dict"] = False
return self.tokenizer.apply_chat_template( return self.tokenizer.apply_chat_template(
processed_messages, processed_messages,
tokenize=tokenize, tokenize=tokenize,
......
...@@ -70,6 +70,7 @@ class DeepseekV32Renderer(RendererLike): ...@@ -70,6 +70,7 @@ class DeepseekV32Renderer(RendererLike):
content_format="string", content_format="string",
) )
kwargs["return_dict"] = False
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
...@@ -100,6 +101,7 @@ class DeepseekV32Renderer(RendererLike): ...@@ -100,6 +101,7 @@ class DeepseekV32Renderer(RendererLike):
content_format="string", content_format="string",
) )
kwargs["return_dict"] = False
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
......
...@@ -70,6 +70,7 @@ class Grok2Renderer(RendererLike): ...@@ -70,6 +70,7 @@ class Grok2Renderer(RendererLike):
content_format="string", content_format="string",
) )
kwargs["return_dict"] = False
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
...@@ -100,6 +101,7 @@ class Grok2Renderer(RendererLike): ...@@ -100,6 +101,7 @@ class Grok2Renderer(RendererLike):
content_format="string", content_format="string",
) )
kwargs["return_dict"] = False
prompt_raw = tokenizer.apply_chat_template( prompt_raw = tokenizer.apply_chat_template(
conversation=conversation, conversation=conversation,
messages=messages, messages=messages,
......
...@@ -465,6 +465,7 @@ def safe_apply_chat_template( ...@@ -465,6 +465,7 @@ def safe_apply_chat_template(
chat_template=chat_template, chat_template=chat_template,
chat_template_kwargs=kwargs, chat_template_kwargs=kwargs,
) )
resolved_kwargs["return_dict"] = False
try: try:
return tokenizer.apply_chat_template( return tokenizer.apply_chat_template(
......
...@@ -432,6 +432,7 @@ class Grok2Tokenizer(TokenizerLike): ...@@ -432,6 +432,7 @@ class Grok2Tokenizer(TokenizerLike):
raise ValueError( raise ValueError(
"No chat template available. Provide `chat_template` explicitly." "No chat template available. Provide `chat_template` explicitly."
) )
kwargs["return_dict"] = False
prompt = hf_chat_utils.apply_chat_template( prompt = hf_chat_utils.apply_chat_template(
conversation=messages, conversation=messages,
chat_template=template, chat_template=template,
......
...@@ -148,8 +148,8 @@ class HunYuanVLProcessor(ProcessorMixin): ...@@ -148,8 +148,8 @@ class HunYuanVLProcessor(ProcessorMixin):
assert 0 assert 0
def apply_chat_template(self, *args, **kwargs): def apply_chat_template(self, *args, **kwargs):
token_ids = self.tokenizer.apply_chat_template(*args, **kwargs) kwargs["return_dict"] = False
return token_ids return self.tokenizer.apply_chat_template(*args, **kwargs)
def get_imgs_pos(self, doc_ids): def get_imgs_pos(self, doc_ids):
doc_ids = np.array(doc_ids, dtype=np.int64) doc_ids = np.array(doc_ids, dtype=np.int64)
......
...@@ -213,6 +213,7 @@ class Qwen3ASRProcessor(ProcessorMixin): ...@@ -213,6 +213,7 @@ class Qwen3ASRProcessor(ProcessorMixin):
return list(_iter()) return list(_iter())
def apply_chat_template(self, conversations, chat_template=None, **kwargs): def apply_chat_template(self, conversations, chat_template=None, **kwargs):
kwargs["return_dict"] = False
return super().apply_chat_template(conversations, chat_template, **kwargs) return super().apply_chat_template(conversations, chat_template, **kwargs)
@property @property
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment