OpenDAS / ColossalAI / Commits

Commit 73b542a1 (unverified)
Authored Mar 29, 2023 by ver217; committed by GitHub on Mar 29, 2023

[coati] inference supports profanity check (#3295)
Parent: ce2cafae

Showing 2 changed files with 29 additions and 4 deletions (+29 -4):

applications/Chat/inference/server.py (+14 -3)
applications/Chat/inference/utils.py (+15 -1)
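In short, the commit threads an optional censored-word list through the inference server: ChatPromptProcessor gains a censored_words parameter and a has_censored_words() helper built on a case-insensitive regex, the non-streaming endpoint generate_no_stream() returns a fixed SAFE_RESPONSE whenever the prompt or the generated output matches the list, and a new --profanity_file flag loads the word list from a JSON file via the new load_json() helper.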
applications/Chat/inference/server.py

...
@@ -14,7 +14,7 @@ from slowapi.errors import RateLimitExceeded
 from slowapi.util import get_remote_address
 from sse_starlette.sse import EventSourceResponse
 from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
-from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
+from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn, load_json

 CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
 MAX_LEN = 512
...
@@ -111,6 +111,8 @@ def generate(data: GenerationTaskReq, request: Request):
 @limiter.limit('1/second')
 def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
+    if prompt_processor.has_censored_words(prompt):
+        return prompt_processor.SAFE_RESPONSE
     inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
     with running_lock:
         output = model.generate(**inputs, **data.dict(exclude={'history'}))
...
@@ -118,7 +120,10 @@ def generate_no_stream(data: GenerationTaskReq, request: Request):
     prompt_len = inputs['input_ids'].size(1)
     response = output[0, prompt_len:]
     out_string = tokenizer.decode(response, skip_special_tokens=True)
-    return prompt_processor.postprocess_output(out_string)
+    out_string = prompt_processor.postprocess_output(out_string)
+    if prompt_processor.has_censored_words(out_string):
+        return prompt_processor.SAFE_RESPONSE
+    return out_string


 if __name__ == '__main__':
...
@@ -140,13 +145,19 @@ if __name__ == '__main__':
         help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
     parser.add_argument('--http_host', default='0.0.0.0')
     parser.add_argument('--http_port', type=int, default=7070)
+    parser.add_argument('--profanity_file',
+                        default=None,
+                        help='Path to profanity words list. It should be a JSON file containing a list of words.')
     args = parser.parse_args()

     if args.quant == '4bit':
         assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'

     tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
-    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
+    if args.profanity_file is not None:
+        censored_words = load_json(args.profanity_file)
+    else:
+        censored_words = []
+    prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN, censored_words=censored_words)

     if args.quant == '4bit':
         model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
...
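To make the control flow easier to see, here is a minimal, runnable sketch of the gating order the commit adds to generate_no_stream(). The names screen and run_model are hypothetical stand-ins for ChatPromptProcessor.has_censored_words and the tokenize/generate/decode steps; only the order of the checks mirrors the diff above.

# Hedged sketch: `screen` and `run_model` are hypothetical stand-ins,
# not part of the commit; the check order mirrors the diff above.
from typing import Callable

SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'

def respond(prompt: str,
            screen: Callable[[str], bool],
            run_model: Callable[[str], str]) -> str:
    if screen(prompt):                 # gate 1: screen the incoming prompt
        return SAFE_RESPONSE
    out_string = run_model(prompt)     # stands in for tokenize + generate + decode
    if screen(out_string):             # gate 2: screen the postprocessed output
        return SAFE_RESPONSE
    return out_string

# Example: block anything containing 'badword', regardless of case.
screen = lambda text: 'badword' in text.lower()
assert respond('hello', screen, lambda p: 'a clean reply') == 'a clean reply'
assert respond('hello', screen, lambda p: 'reply with BADWORD') == SAFE_RESPONSE

Screening both sides means a censored word can neither be smuggled in through the prompt nor emitted by the model; either way the client receives the same fixed SAFE_RESPONSE.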
applications/Chat/inference/utils.py

 import re
 from threading import Lock
 from typing import Any, Callable, Generator, List, Optional
+import json

 import torch
 import torch.distributed as dist
...
@@ -123,11 +124,16 @@ STOP_PAT = re.compile(r'(###|instruction:).*', flags=(re.I | re.S))
 class ChatPromptProcessor:
+    SAFE_RESPONSE = 'The input/response contains inappropriate content, please rephrase your prompt.'

-    def __init__(self, tokenizer, context: str, max_len: int = 2048):
+    def __init__(self, tokenizer, context: str, max_len: int = 2048, censored_words: List[str] = []):
         self.tokenizer = tokenizer
         self.context = context
         self.max_len = max_len
+        if len(censored_words) > 0:
+            self.censored_pat = re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
+        else:
+            self.censored_pat = None
         # These will be initialized after the first call of preprocess_prompt()
         self.context_len: Optional[int] = None
         self.dialogue_placeholder_len: Optional[int] = None
...
@@ -172,6 +178,10 @@ class ChatPromptProcessor:
         output = STOP_PAT.sub('', output)
         return output.strip()

+    def has_censored_words(self, text: str) -> bool:
+        if self.censored_pat is None:
+            return False
+        return self.censored_pat.search(text) is not None
+

 class LockedIterator:
...
@@ -185,3 +195,7 @@ class LockedIterator:
     def __next__(self):
         with self.lock:
             return next(self.it)
+
+def load_json(path: str):
+    with open(path) as f:
+        return json.load(f)
\ No newline at end of file
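The heart of the check is the one pattern ChatPromptProcessor builds in __init__. The sketch below reproduces that line in isolation (build_censored_pat is a hypothetical wrapper, not part of the commit) to show what the construction buys: regex metacharacters in words are neutralized by re.escape, matching is case-insensitive via flags=re.I, and an empty list disables the check entirely.

# Hedged sketch: build_censored_pat is a hypothetical wrapper around the
# exact pattern-construction line from the diff above.
import re
from typing import List, Optional

def build_censored_pat(censored_words: List[str]) -> Optional[re.Pattern]:
    if len(censored_words) > 0:
        return re.compile(f'({"|".join(map(re.escape, censored_words))})', flags=re.I)
    return None

pat = build_censored_pat(['badword', 'c++'])   # 'c++' shows why re.escape matters
assert pat.search('this has a BADWORD in it') is not None    # flags=re.I
assert pat.search('I love c++ templates') is not None        # metacharacters escaped
assert pat.search('perfectly clean text') is None
assert build_censored_pat([]) is None                        # empty list: check disabled

Note that the search is substring-based, with no word boundaries, so a censored word embedded inside a longer word also triggers SAFE_RESPONSE. Note also that censored_words: List[str] = [] is a mutable default argument, a classic Python pitfall, though harmless here since the list is only read.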