sglang · Commits · 1aea19f6

Unverified commit 1aea19f6, authored Nov 25, 2024 by Rin Intachuen, committed by GitHub on Nov 25, 2024
Parent: 1f76fc6e

Input_embeds support (#2052)

Showing 9 changed files with 204 additions and 15 deletions.
Changed files:

  docs/references/sampling_params.md                        +6    -4
  python/sglang/srt/managers/io_struct.py                   +25   -5
  python/sglang/srt/managers/schedule_batch.py              +21   -0
  python/sglang/srt/managers/scheduler.py                   +8    -0
  python/sglang/srt/managers/tokenizer_manager.py           +14   -3
  python/sglang/srt/model_executor/forward_batch_info.py    +4    -0
  python/sglang/srt/model_executor/model_runner.py          +11   -3
  test/srt/run_suite.py                                     +1    -0
  test/srt/test_input_embeddings.py (new file)              +114  -0

  Total                                                     +204  -15
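For orientation, here is a minimal client-side sketch of how the new `input_embeds` field can be exercised, modeled on the test added in this commit. It assumes a local sglang server that was already launched with the radix cache disabled; the model name and URL are placeholders, not values from the commit.

    import requests
    from transformers import AutoModelForCausalLM, AutoTokenizer

    MODEL = "your-model-name"              # placeholder: any causal LM served by sglang
    BASE_URL = "http://127.0.0.1:30000"    # placeholder: local server URL

    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    ref_model = AutoModelForCausalLM.from_pretrained(MODEL)

    # Turn a prompt into embeddings using the model's own embedding table.
    input_ids = tokenizer("The capital of France is", return_tensors="pt")["input_ids"]
    embeds = ref_model.get_input_embeddings()(input_ids).squeeze().tolist()

    # Send the embeddings instead of text / input_ids to the /generate endpoint.
    resp = requests.post(
        BASE_URL + "/generate",
        json={
            "input_embeds": embeds,
            "sampling_params": {"temperature": 0, "max_new_tokens": 50},
        },
        timeout=30,
    )
    print(resp.json()["text"])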
docs/references/sampling_params.md

@@ -11,21 +11,23 @@ The `/generate` endpoint accepts the following arguments in the JSON format.
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params. See descriptions below.
-    sampling_params: Union[List[Dict], Dict] = None
+    sampling_params: Optional[Union[List[Dict], Dict]] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
     # Whether to return logprobs.
     return_logprob: Optional[Union[List[bool], bool]] = None
-    # The start location of the prompt for return_logprob.
+    # If return logprobs, the start location in the prompt for returning logprobs.
     # By default, this value is "-1", which means it will only return logprobs for output tokens.
     logprob_start_len: Optional[Union[List[int], int]] = None
-    # The number of top logprobs to return.
+    # If return logprobs, the number of top logprobs to return at each position.
     top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
python/sglang/srt/managers/io_struct.py

@@ -29,8 +29,10 @@ from sglang.srt.sampling.sampling_params import SamplingParams
 class GenerateReqInput:
     # The input prompt. It can be a single prompt or a batch of prompts.
     text: Optional[Union[List[str], str]] = None
-    # The token ids for text; one can either specify text or input_ids.
+    # The token ids for text; one can specify either text or input_ids.
     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The embeddings for input_ids; one can specify either text or input_ids or input_embeds.
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None

@@ -60,10 +62,16 @@ class GenerateReqInput:
     ] = None

     def normalize_batch_and_arguments(self):
-        if (self.text is None and self.input_ids is None) or (
-            self.text is not None and self.input_ids is not None
+        if (
+            self.text is None and self.input_ids is None and self.input_embeds is None
+        ) or (
+            self.text is not None
+            and self.input_ids is not None
+            and self.input_embeds is not None
         ):
-            raise ValueError("Either text or input_ids should be provided.")
+            raise ValueError(
+                "Either text, input_ids or input_embeds should be provided."
+            )

         # Derive the batch size
         if self.text is not None:

@@ -73,13 +81,21 @@ class GenerateReqInput:
             else:
                 self.is_single = False
                 self.batch_size = len(self.text)
-        else:
+            self.input_embeds = None
+        elif self.input_ids is not None:
             if isinstance(self.input_ids[0], int):
                 self.is_single = True
                 self.batch_size = 1
             else:
                 self.is_single = False
                 self.batch_size = len(self.input_ids)
+            self.input_embeds = None
+        else:
+            if isinstance(self.input_embeds[0][0], float):
+                self.is_single = True
+                self.batch_size = 1
+            else:
+                self.batch_size = len(self.input_embeds)

         # Handle parallel sampling
         # When parallel sampling is used, we always treat the input as a batch.

@@ -202,6 +218,8 @@ class TokenizedGenerateReqInput:
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
+    # The input embeds
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
     # Session id info for continual prompting
     session_id: Optional[str] = None

@@ -218,6 +236,8 @@ class EmbeddingReqInput:
     rid: Optional[Union[List[str], str]] = None
     # Dummy sampling params for compatibility
     sampling_params: Union[List[Dict], Dict] = None
+    # Dummy input embeds for compatibility
+    input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
python/sglang/srt/managers/schedule_batch.py

@@ -178,6 +178,7 @@ class Req:
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
         lora_path: Optional[str] = None,
+        input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
     ):
         # Input and output info

@@ -191,6 +192,7 @@ class Req:
         self.sampling_params = sampling_params
         self.lora_path = lora_path
+        self.input_embeds = input_embeds

         # Memory pool info
         self.req_pool_idx = None

@@ -448,6 +450,7 @@ class ScheduleBatch:
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
+    input_embeds: torch.Tensor = None
     req_pool_indices: torch.Tensor = None
     seq_lens: torch.Tensor = None
     # The output locations of the KV cache

@@ -631,6 +634,9 @@ class ScheduleBatch:
         req_pool_indices = self.alloc_req_slots(bs)
         out_cache_loc = self.alloc_token_slots(extend_num_tokens)
+
+        input_embeds = []
+
         pt = 0
         for i, req in enumerate(reqs):
             already_computed = (
                 req.extend_logprob_start_len + 1 + req.cached_tokens

@@ -649,6 +655,11 @@ class ScheduleBatch:
                 (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
             )

+            # If input_embeds are available, store them
+            if req.input_embeds is not None:
+                # If req.input_embeds is already a list, append its content directly
+                input_embeds.extend(req.input_embeds)  # Use extend to avoid nesting
+
             # Compute the relative logprob_start_len in an extend batch
             if req.logprob_start_len >= pre_len:
                 extend_logprob_start_len = min(

@@ -671,6 +682,12 @@ class ScheduleBatch:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
+        self.input_embeds = (
+            torch.tensor(input_embeds).to(self.device, non_blocking=True)
+            if input_embeds
+            else None
+        )
+
         self.out_cache_loc = out_cache_loc

         self.seq_lens_sum = sum(seq_lens)

@@ -1053,6 +1070,7 @@ class ScheduleBatch:
             encoder_out_cache_loc=self.encoder_out_cache_loc,
             lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
+            input_embeds=self.input_embeds,
         )

     def copy(self):

@@ -1123,6 +1141,9 @@ class ModelWorkerBatch:
     # Sampling info
     sampling_info: SamplingBatchInfo

+    # The input Embeds
+    input_embeds: Optional[torch.tensor] = None
+

 @triton.jit
 def write_req_to_token_pool_triton(
python/sglang/srt/managers/scheduler.py

@@ -526,12 +526,20 @@ class Scheduler:
         recv_req: TokenizedGenerateReqInput,
     ):
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
+            # Check if input_embeds is present and create dummy input_ids
+            if recv_req.input_embeds is not None:
+                # Generate fake input_ids based on the length of input_embeds
+                seq_length = len(recv_req.input_embeds)
+                fake_input_ids = [1] * seq_length
+                recv_req.input_ids = fake_input_ids
+
             req = Req(
                 recv_req.rid,
                 recv_req.input_text,
                 recv_req.input_ids,
                 recv_req.sampling_params,
                 lora_path=recv_req.lora_path,
+                input_embeds=recv_req.input_embeds,
             )
             req.tokenizer = self.tokenizer

             if recv_req.session_id is not None:
python/sglang/srt/managers/tokenizer_manager.py

@@ -201,8 +201,18 @@ class TokenizerManager:
     ):
         """Tokenize one request."""
         # Tokenize
+        input_embeds = None
         input_text = obj.text
-        if obj.input_ids is None:
+        if obj.input_embeds is not None:
+            if not self.server_args.disable_radix_cache:
+                raise ValueError(
+                    "input_embeds is provided while disable_radix_cache is False. "
+                    "Please add `--disable-radix-cache` when you launch the server "
+                    "if you want to use input_embeds as inputs."
+                )
+            input_embeds = obj.input_embeds
+            input_ids = obj.input_ids
+        elif obj.input_ids is None:
             input_ids = self.tokenizer.encode(input_text)
         else:
             input_ids = obj.input_ids

@@ -219,7 +229,7 @@ class TokenizerManager:
         session_id = obj.session[0] if obj.session else None
         session_rid = obj.session[1] if obj.session else None

-        if len(input_ids) >= self.context_len:
+        if obj.input_ids is not None and len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."

@@ -242,7 +252,8 @@ class TokenizerManager:
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            obj.lora_path,
+            lora_path=obj.lora_path,
+            input_embeds=input_embeds,
             session_id=session_id,
             session_rid=session_rid,
         )
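Because this check rejects input_embeds whenever the radix cache is enabled, the server has to be launched with the cache disabled before the client example above will work. A minimal launch sketch using the repository's own test helper (the helper and constants are taken from the new test added in this commit; the exact flag spelling `--disable-radix-cache` is an assumption based on the `disable_radix_cache` server argument referenced here):

    from sglang.test.test_utils import (
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        DEFAULT_URL_FOR_TEST,
        popen_launch_server,
    )

    # Launch a local server with the radix cache disabled so that
    # TokenizerManager accepts input_embeds in /generate requests.
    process = popen_launch_server(
        DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
        DEFAULT_URL_FOR_TEST,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        other_args=["--disable-radix-cache"],
    )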
python/sglang/srt/model_executor/forward_batch_info.py

@@ -130,6 +130,9 @@ class ForwardBatch:
     # For LoRA
     lora_paths: Optional[List[str]] = None

+    # For input embeddings
+    input_embeds: Optional[torch.tensor] = None
+
     # Sampling info
     sampling_info: SamplingBatchInfo = None

@@ -231,6 +234,7 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            input_embeds=batch.input_embeds,
         )

         if ret.global_num_tokens is not None:
python/sglang/srt/model_executor/model_runner.py

@@ -606,9 +606,17 @@ class ModelRunner:
     def forward_extend(self, forward_batch: ForwardBatch):
         self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
-            return self.model.forward(
-                forward_batch.input_ids, forward_batch.positions, forward_batch
-            )
+            if forward_batch.input_embeds is None:
+                return self.model.forward(
+                    forward_batch.input_ids, forward_batch.positions, forward_batch
+                )
+            else:
+                return self.model.forward(
+                    forward_batch.input_ids,
+                    forward_batch.positions,
+                    forward_batch,
+                    input_embeds=forward_batch.input_embeds.bfloat16(),
+                )
         else:
             # Only embedding models have get_embedding parameter
             return self.model.forward(
test/srt/run_suite.py

@@ -14,6 +14,7 @@ suites = {
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
         "test_eval_accuracy_mini.py",
+        "test_input_embeddings.py",
         "test_json_constrained.py",
         "test_large_max_new_tokens.py",
         "test_metrics.py",
test/srt/test_input_embeddings.py (new file, 0 → 100644)

import json
import unittest

import requests
from transformers import AutoModelForCausalLM, AutoTokenizer

from sglang.srt.utils import kill_child_process
from sglang.test.test_utils import (
    DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestInputEmbeds(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=["--disable-radix"],
        )
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]

    def generate_input_embeddings(self, text):
        """Generate input embeddings for a given text."""
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        return embeddings.squeeze().tolist()  # Convert tensor to a list for API use

    def send_request(self, payload):
        """Send a POST request to the API and return the response."""
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=30,  # Set a reasonable timeout for the API request
        )
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def test_text_based_response(self):
        """Print API response using text-based input."""
        for text in self.texts:
            payload = {
                "model": self.model,
                "text": text,
                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
            }
            response = self.send_request(payload)
            print(
                f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_embedding_based_response(self):
        """Print API response using input embeddings."""
        for text in self.texts:
            embeddings = self.generate_input_embeddings(text)
            payload = {
                "model": self.model,
                "input_embeds": embeddings,
                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
            }
            response = self.send_request(payload)
            print(
                f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )

    def test_compare_text_vs_embedding(self):
        """Print responses for both text-based and embedding-based inputs."""
        for text in self.texts:
            # Text-based payload
            text_payload = {
                "model": self.model,
                "text": text,
                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
            }
            # Embedding-based payload
            embeddings = self.generate_input_embeddings(text)
            embed_payload = {
                "model": self.model,
                "input_embeds": embeddings,
                "sampling_params": {"temperature": 0, "max_new_tokens": 50},
            }
            # Get responses
            text_response = self.send_request(text_payload)
            embed_response = self.send_request(embed_payload)
            # Print responses
            print(
                f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
            )
            print(
                f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
            )
            self.assertEqual(text_response["text"], embed_response["text"])

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)


if __name__ == "__main__":
    unittest.main()