Unverified Commit 7a7ac6be, authored by Yang Zheng and committed by GitHub

[FIX] Update EOS from config (#2475)

parent d9e6ee38
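
This change reads the end-of-sequence (EOS) token id(s) from the model's Hugging Face config and threads them through to the request finish check, instead of relying only on the tokenizer's single `eos_token_id`. In HF configs the `eos_token_id` field can be either a single int or a list of ints; here is a minimal sketch of the two shapes (the ids below are illustrative, not taken from this commit):

```python
# Illustrative eos_token_id shapes as they appear in Hugging Face configs.
single_eos = {"eos_token_id": 2}                        # one EOS id
multi_eos = {"eos_token_id": [128001, 128008, 128009]}  # several EOS ids

for cfg in (single_eos, multi_eos):
    raw = cfg["eos_token_id"]
    # Normalize both shapes to a set, mirroring get_hf_eos_token_id() below.
    eos_ids = {raw} if isinstance(raw, int) else set(raw)
    print(eos_ids)
```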
```diff
@@ -15,7 +15,8 @@
 import json
 import logging
 from enum import IntEnum, auto
-from typing import List, Optional, Union
+from functools import lru_cache
+from typing import List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
```
```diff
@@ -271,6 +272,14 @@ class ModelConfig:
             self.quantization,
         )
 
+    @lru_cache()
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
+
 
 def get_hf_text_config(config: PretrainedConfig):
     """Get the "sub" config relevant to llm for multi modal models.
```
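A quick usage sketch of the new helper's normalization behavior; the `SimpleNamespace` stand-in for `hf_config` and the `normalize_eos` wrapper are mine, not part of the commit:

```python
from types import SimpleNamespace

def normalize_eos(hf_config):
    # Same logic as ModelConfig.get_hf_eos_token_id, minus the lru_cache.
    eos_ids = getattr(hf_config, "eos_token_id", None)
    if eos_ids:
        # Can be either an int or a list of ints.
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids

assert normalize_eos(SimpleNamespace(eos_token_id=2)) == {2}
assert normalize_eos(SimpleNamespace(eos_token_id=[128001, 128009])) == {128001, 128009}
assert normalize_eos(SimpleNamespace()) is None
# Note: an id of 0 would skip normalization and be returned as-is,
# a quirk of the truthiness check in the committed code.
```

One design note: `@lru_cache()` on an instance method keys the cache on its arguments (here just `self`) and keeps a reference to the instance alive; for a single long-lived `ModelConfig` that trade-off is harmless and makes the per-request call in the scheduler hunk below effectively free.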
```diff
@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 import dataclasses
 import logging
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
```
```diff
@@ -209,6 +209,7 @@ class Req:
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
+        eos_token_ids: Optional[Set[int]] = None,
     ):
         # Input and output info
         self.rid = rid
```
```diff
@@ -236,6 +237,7 @@ class Req:
         self.finished_reason = None
         self.to_abort = False
         self.stream = stream
+        self.eos_token_ids = eos_token_ids
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
```
```diff
@@ -395,18 +397,23 @@ class Req:
         last_token_id = self.output_ids[-1]
 
-        matched_eos = False
-
-        # Check stop token ids
-        if self.sampling_params.stop_token_ids:
-            matched_eos = last_token_id in self.sampling_params.stop_token_ids
-        if self.tokenizer is not None:
-            matched_eos |= last_token_id == self.tokenizer.eos_token_id
-            if self.tokenizer.additional_stop_token_ids:
-                matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids
-        if matched_eos and not self.sampling_params.ignore_eos:
-            self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
-            return
+        if not self.sampling_params.ignore_eos:
+            matched_eos = False
+
+            # Check stop token ids
+            if self.sampling_params.stop_token_ids:
+                matched_eos = last_token_id in self.sampling_params.stop_token_ids
+            if self.eos_token_ids:
+                matched_eos |= last_token_id in self.eos_token_ids
+            if self.tokenizer is not None:
+                matched_eos |= last_token_id == self.tokenizer.eos_token_id
+                if self.tokenizer.additional_stop_token_ids:
+                    matched_eos |= (
+                        last_token_id in self.tokenizer.additional_stop_token_ids
+                    )
+            if matched_eos:
+                self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
+                return
 
         # Check stop strings
         if len(self.sampling_params.stop_strs) > 0:
```
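The restructured check also short-circuits differently: with `ignore_eos` set, none of the stop-token comparisons run at all, whereas before they were computed and then discarded. Below is a standalone sketch of the new flow; the function name and flat parameters are mine for illustration, and the tokenizer's `additional_stop_token_ids` (folded in the same way) is omitted for brevity:

```python
from typing import Iterable, Optional, Set

def matches_stop_token(
    last_token_id: int,
    ignore_eos: bool,
    stop_token_ids: Optional[Iterable[int]],
    eos_token_ids: Optional[Set[int]],  # from ModelConfig.get_hf_eos_token_id()
    tokenizer_eos_token_id: Optional[int],
) -> bool:
    if ignore_eos:
        return False  # skip every EOS/stop-token check, as in the new code
    matched = bool(stop_token_ids) and last_token_id in stop_token_ids
    if eos_token_ids:
        # The new config-derived check: catches models whose config lists
        # several EOS ids that the tokenizer's single eos_token_id misses.
        matched |= last_token_id in eos_token_ids
    if tokenizer_eos_token_id is not None:
        matched |= last_token_id == tokenizer_eos_token_id
    return matched
```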
```diff
@@ -517,6 +517,7 @@ class Scheduler:
             stream=recv_req.stream,
             lora_path=recv_req.lora_path,
             input_embeds=recv_req.input_embeds,
+            eos_token_ids=self.model_config.get_hf_eos_token_id(),
         )
         req.tokenizer = self.tokenizer
```
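End to end, the config-derived EOS set now flows from `ModelConfig` into every `Req` the scheduler constructs. A hedged wiring sketch with simplified stand-ins for the classes this diff touches:

```python
from typing import Optional, Set

class StubModelConfig:
    """Stand-in exposing only the method this diff adds."""
    def __init__(self, eos_ids: Optional[Set[int]]):
        self._eos_ids = eos_ids
    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
        return self._eos_ids

class StubReq:
    """Stand-in for Req, keeping only the new field."""
    def __init__(self, eos_token_ids: Optional[Set[int]] = None):
        self.eos_token_ids = eos_token_ids

model_config = StubModelConfig({128001, 128009})
req = StubReq(eos_token_ids=model_config.get_hf_eos_token_id())
assert req.eos_token_ids == {128001, 128009}
```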