Unverified commit 62f71561, authored by ver217, committed by GitHub
Browse files

[coati] fix inference profanity check (#3299)

parent 5134ad5d
......@@ -10,3 +10,4 @@ uvicorn
git+https://github.com/huggingface/transformers
accelerate
bitsandbytes
jieba
\ No newline at end of file
......@@ -2,6 +2,7 @@ import re
from threading import Lock
from typing import Any, Callable, Generator, List, Optional
import json
import jieba
import torch
import torch.distributed as dist
......@@ -130,10 +131,7 @@ class ChatPromptProcessor:
self.tokenizer = tokenizer
self.context = context
self.max_len = max_len
if len(censored_words) > 0:
self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
else:
self.censored_pat = None
self.censored_words = set([word.lower() for word in censored_words])
# These will be initialized after the first call of preprocess_prompt()
self.context_len: Optional[int] = None
self.dialogue_placeholder_len: Optional[int] = None
......@@ -179,9 +177,10 @@ class ChatPromptProcessor:
return output.strip()
def has_censored_words(self, text: str) -> bool:
if self.censored_pat is None:
if len(self.censored_words) == 0:
return False
return self.censored_pat.search(text) is not None
intersection = set(jieba.cut(text.lower())) & self.censored_words
return len(intersection) > 0
class LockedIterator:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment