Unverified Commit 33f0f67f authored by Yang Yong (雍洋)'s avatar Yang Yong (雍洋) Committed by GitHub
Browse files

Fix hunyuan 1.5 get_byt5_text_tokens (#565)

parent be303ba9
......@@ -257,6 +257,29 @@ class ByT5TextEncoder:
result = list(dict.fromkeys(result)) if len(result) > 1 else result
return result
def get_byt5_text_tokens(self, byt5_tokenizer, byt5_max_length, text_prompt):
    """
    Run the byT5 tokenizer on a prompt and return its ids and mask.

    The tokenizer is asked to pad to ``max_length``, truncate, add
    special tokens and return PyTorch tensors (``return_tensors="pt"``).

    Args:
        byt5_tokenizer: Callable tokenizer; its result must expose
            ``input_ids`` and ``attention_mask`` attributes.
        byt5_max_length: Value forwarded as the tokenizer's ``max_length``.
        text_prompt: Text prompt passed as the tokenizer's first argument.

    Returns:
        Tuple of (input_ids, attention_mask) taken from the tokenizer output.
    """
    # Collect the fixed tokenizer settings in one place so the call stays short.
    tokenizer_options = {
        "padding": "max_length",
        "max_length": byt5_max_length,
        "truncation": True,
        "add_special_tokens": True,
        "return_tensors": "pt",
    }
    encoded = byt5_tokenizer(text_prompt, **tokenizer_options)

    token_ids = encoded.input_ids
    mask = encoded.attention_mask
    return token_ids, mask
def _process_single_byt5_prompt(self, prompt_text, device):
"""
Process a single prompt for byT5 encoding.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment