Unverified Commit 7a7ac6be authored by Yang Zheng's avatar Yang Zheng Committed by GitHub
Browse files

[FIX] Update EOS from config (#2475)

parent d9e6ee38
...@@ -15,7 +15,8 @@ ...@@ -15,7 +15,8 @@
import json import json
import logging import logging
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import List, Optional, Union from functools import lru_cache
from typing import List, Optional, Set, Union
import torch import torch
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -271,6 +272,14 @@ class ModelConfig: ...@@ -271,6 +272,14 @@ class ModelConfig:
self.quantization, self.quantization,
) )
@lru_cache()
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
eos_ids = getattr(self.hf_config, "eos_token_id", None)
if eos_ids:
# it can be either int or list of int
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
return eos_ids
def get_hf_text_config(config: PretrainedConfig): def get_hf_text_config(config: PretrainedConfig):
"""Get the "sub" config relevant to llm for multi modal models. """Get the "sub" config relevant to llm for multi modal models.
......
...@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch ...@@ -29,7 +29,7 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
import dataclasses import dataclasses
import logging import logging
from typing import List, Optional, Tuple, Union from typing import List, Optional, Set, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -209,6 +209,7 @@ class Req: ...@@ -209,6 +209,7 @@ class Req:
lora_path: Optional[str] = None, lora_path: Optional[str] = None,
input_embeds: Optional[List[List[float]]] = None, input_embeds: Optional[List[List[float]]] = None,
session_id: Optional[str] = None, session_id: Optional[str] = None,
eos_token_ids: Optional[Set[int]] = None,
): ):
# Input and output info # Input and output info
self.rid = rid self.rid = rid
...@@ -236,6 +237,7 @@ class Req: ...@@ -236,6 +237,7 @@ class Req:
self.finished_reason = None self.finished_reason = None
self.to_abort = False self.to_abort = False
self.stream = stream self.stream = stream
self.eos_token_ids = eos_token_ids
# For incremental decoding # For incremental decoding
# ----- | --------- read_ids -------| # ----- | --------- read_ids -------|
...@@ -395,18 +397,23 @@ class Req: ...@@ -395,18 +397,23 @@ class Req:
last_token_id = self.output_ids[-1] last_token_id = self.output_ids[-1]
matched_eos = False if not self.sampling_params.ignore_eos:
matched_eos = False
# Check stop token ids
if self.sampling_params.stop_token_ids: # Check stop token ids
matched_eos = last_token_id in self.sampling_params.stop_token_ids if self.sampling_params.stop_token_ids:
if self.tokenizer is not None: matched_eos = last_token_id in self.sampling_params.stop_token_ids
matched_eos |= last_token_id == self.tokenizer.eos_token_id if self.eos_token_ids:
if self.tokenizer.additional_stop_token_ids: matched_eos |= last_token_id in self.eos_token_ids
matched_eos |= last_token_id in self.tokenizer.additional_stop_token_ids if self.tokenizer is not None:
if matched_eos and not self.sampling_params.ignore_eos: matched_eos |= last_token_id == self.tokenizer.eos_token_id
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id) if self.tokenizer.additional_stop_token_ids:
return matched_eos |= (
last_token_id in self.tokenizer.additional_stop_token_ids
)
if matched_eos:
self.finished_reason = FINISH_MATCHED_TOKEN(matched=last_token_id)
return
# Check stop strings # Check stop strings
if len(self.sampling_params.stop_strs) > 0: if len(self.sampling_params.stop_strs) > 0:
......
...@@ -517,6 +517,7 @@ class Scheduler: ...@@ -517,6 +517,7 @@ class Scheduler:
stream=recv_req.stream, stream=recv_req.stream,
lora_path=recv_req.lora_path, lora_path=recv_req.lora_path,
input_embeds=recv_req.input_embeds, input_embeds=recv_req.input_embeds,
eos_token_ids=self.model_config.get_hf_eos_token_id(),
) )
req.tokenizer = self.tokenizer req.tokenizer = self.tokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment