sglang / Commits / 835f8afc
"tests/vscode:/vscode.git/clone" did not exist on "fa22d9db4e9aa44adbf8ce0653dff1ec441bd5e6"
Commit 835f8afc authored Dec 08, 2024 by Lianmin Zheng

Migrate llama_classification to use the /classify interface (#2417)

parent 3844feb9
Showing 2 changed files with 30 additions and 25 deletions:

python/sglang/srt/models/llama_classification.py  +11 -22
scripts/deprecated/test_httpserver_classify.py    +19 -3
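In short, clients stop scraping "normalized_prompt_logprob" out of /generate responses and instead call the dedicated /classify route, which returns the classification scores in the response's "embedding" field. A minimal client sketch, mirroring the request and response shape used in the updated test script below (the port 30000 is an assumption, SGLang's usual default; the server must be launched with --is-embedding first):

    import requests

    # Port 30000 is an assumed default; the payload and response shape
    # come from the updated test script in this commit.
    resp = requests.post(
        "http://127.0.0.1:30000/classify",
        json={"text": "This is a test sentence."},
    )
    scores = resp.json()["embedding"]  # one score per class
    print(scores)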
python/sglang/srt/models/llama_classification.py
@@ -18,7 +18,7 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -40,7 +40,7 @@ class LlamaForClassification(nn.Module):
         self.classification_head = nn.Linear(
             config.hidden_size, config.classification_out_size, bias=False
         )
-        self.eos_token_id = config.eos_token_id
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)
 
     @torch.no_grad()
     def forward(
@@ -49,28 +49,17 @@ class LlamaForClassification(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        is_eos_token = input_ids == self.eos_token_id
-        hidden_states = hidden_states[is_eos_token]
-        scores = self.classification_head(hidden_states)
+        get_embedding: bool = True,
+    ) -> EmbeddingPoolerOutput:
+        assert (
+            get_embedding
+        ), "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
 
-        if scores.shape[0] != forward_batch.batch_size:
-            print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones(
-                (forward_batch.batch_size, self.config.classification_out_size)
-            ).to(input_ids.device)
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=scores,
-            next_token_logprobs=scores,
-            normalized_prompt_logprobs=scores,
-            input_token_logprobs=torch.ones_like(input_ids),
-            input_top_logprobs=None,
-            output_top_logprobs=None,
-        )
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
+        scores = self.classification_head(last_token_hidden)
 
-        return logits_output
+        return EmbeddingPoolerOutput(scores)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
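The behavioral change in forward: the old path selected hidden states wherever input_ids equaled eos_token_id and fell back to dummy scores when an EOS token was missing, while the new path delegates to Pooler(PoolingType.LAST), which always takes the last token of each sequence, so the fallback disappears. A minimal sketch of the last-token pooling idea, using a hypothetical pool_last_token helper rather than SGLang's actual Pooler, and assuming the batch's sequences are packed back to back:

    import torch

    def pool_last_token(hidden_states: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        # hidden_states: [total_tokens, hidden_size], all sequences packed end to end.
        # seq_lens: [batch_size], length of each packed sequence.
        # The last token of sequence i sits at index cumsum(seq_lens)[i] - 1.
        last_indices = torch.cumsum(seq_lens, dim=0) - 1
        return hidden_states[last_indices]

    # Two packed sequences of lengths 3 and 4 -> rows 2 and 6 are selected.
    hidden = torch.randn(7, 16)
    pooled = pool_last_token(hidden, torch.tensor([3, 4]))
    assert pooled.shape == (2, 16)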
scripts/deprecated/test_httpserver_classify.py
"""
Usage:
python3 -m sglang.launch_server
--disable-cuda-graph
--model-path /model/llama-classification
python3 -m sglang.launch_server --model-path /model/llama-classification
--is-embedding --disable-radix-cache
python3 test_httpserver_classify.py
"""
@@ -11,7 +11,7 @@ import numpy as np
 import requests
 
 
-def get_logits(url, prompt):
+def get_logits_deprecated(url: str, prompt: str):
     response = requests.post(
         url + "/generate",
         json={
@@ -25,7 +25,7 @@ def get_logits(url, prompt):
     return response.json()["meta_info"]["normalized_prompt_logprob"]
 
 
-def get_logits_batch(url, prompts):
+def get_logits_batch_deprecated(url: str, prompts: list[str]):
     response = requests.post(
         url + "/generate",
         json={
@@ -46,6 +46,22 @@ def get_logits_batch(url, prompts):
     return logits
 
 
+def get_logits(url: str, prompt: str):
+    response = requests.post(
+        url + "/classify",
+        json={"text": prompt},
+    )
+    return response.json()["embedding"]
+
+
+def get_logits_batch(url: str, prompts: list[str]):
+    response = requests.post(
+        url + "/classify",
+        json={"text": prompts},
+    )
+    return np.array([x["embedding"] for x in response.json()])
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
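Putting both files together, an end-to-end smoke test might look like the sketch below. The launch flags come from the updated usage docstring; the port and prompts are illustrative assumptions, and get_logits / get_logits_batch are the new /classify-based functions defined in this script:

    # Launch the server first (flags from the updated usage docstring):
    #   python3 -m sglang.launch_server --model-path /model/llama-classification \
    #       --is-embedding --disable-radix-cache

    url = "http://127.0.0.1:30000"  # assumed default port

    single = get_logits(url, "The movie was wonderful.")              # list of class scores
    batch = get_logits_batch(url, ["Great film!", "Terrible plot."])  # ndarray [2, num_classes]
    print(single, batch.shape)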