Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
20a4f927
Unverified
Commit
20a4f927
authored
Aug 08, 2024
by
Ying Sheng
Committed by
GitHub
Aug 08, 2024
Browse files
Add io struct for embedding models [unreachable code] - step 2/3 (#987)
parent
0de7c2d0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
146 additions
and
4 deletions
+146
-4
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+60
-0
python/sglang/srt/models/llama2.py
python/sglang/srt/models/llama2.py
+2
-2
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+68
-2
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+16
-0
No files found.
python/sglang/srt/managers/io_struct.py
View file @
20a4f927
...
...
@@ -22,6 +22,8 @@ import uuid
from
dataclasses
import
dataclass
from
typing
import
Dict
,
List
,
Optional
,
Union
import
torch
from
sglang.srt.managers.schedule_batch
import
BaseFinishReason
from
sglang.srt.sampling_params
import
SamplingParams
...
...
@@ -166,6 +168,56 @@ class TokenizedGenerateReqInput:
stream
:
bool
@dataclass
class EmbeddingReqInput:
    """Request payload for embedding one prompt or a batch of prompts.

    Exactly one of ``text`` and ``input_ids`` must be provided.

    ``post_init`` is a regular method, not the dataclass ``__post_init__``
    hook, so it is NOT run automatically on construction; the caller has to
    invoke it to validate the request and fill in the derived attributes
    (``is_single``, ``batch_size``, ``rid``, ``sampling_params``).
    """

    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Optional[Union[List[Dict], Dict]] = None

    def post_init(self):
        """Validate the request and populate the derived attributes.

        Sets ``self.is_single``; for batched input also sets
        ``self.batch_size``. Generates ``rid`` when none was supplied and
        overwrites ``sampling_params`` with ``{"max_new_tokens": 0}`` (one
        dict per prompt for batched input).

        Raises:
            ValueError: If neither or both of ``text`` and ``input_ids`` are
                set, if ``input_ids`` is empty, or if the input is batched
                and ``rid`` is given but is not a list.
        """
        has_text = self.text is not None
        has_ids = self.input_ids is not None
        # Exactly one of the two inputs must be present.
        if has_text == has_ids:
            raise ValueError("Either text or input_ids should be provided.")

        if has_text:
            is_single = isinstance(self.text, str)
        else:
            # An empty list cannot be classified as "one prompt" or "a batch"
            # by looking at its first element; reject it with the same
            # exception type as the other validation failures instead of
            # letting the index below raise IndexError.
            if len(self.input_ids) == 0:
                raise ValueError("input_ids should not be empty.")
            is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            self.sampling_params = {"max_new_tokens": 0}
            return

        # support select operation
        self.batch_size = len(self.text) if has_text else len(self.input_ids)
        if self.rid is None:
            self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
        elif not isinstance(self.rid, list):
            raise ValueError("The rid should be a list.")
        self.sampling_params = [
            {"max_new_tokens": 0} for _ in range(self.batch_size)
        ]
@dataclass
class TokenizedEmbeddingReqInput:
    """Tokenized form of a single embedding request.

    NOTE(review): presumably produced from one entry of an embedding request
    after tokenization -- confirm against the code that constructs it.
    """

    # Request id.
    rid: str
    # Prompt text (presumably the text ``input_ids`` was derived from -- confirm).
    input_text: str
    # Token ids of the prompt.
    input_ids: List[int]
    # Sampling parameters attached to the request.
    sampling_params: SamplingParams
@
dataclass
class
BatchTokenIDOut
:
rids
:
List
[
str
]
...
...
@@ -187,6 +239,14 @@ class BatchStrOut:
finished_reason
:
List
[
BaseFinishReason
]
@dataclass
class BatchEmbeddingOut:
    """Embedding results for a batch of requests.

    NOTE(review): the four lists are presumably index-aligned, one entry per
    request -- confirm against the producer.
    """

    # Request ids.
    rids: List[str]
    # One embedding vector (list of floats) per entry.
    embeddings: List[List[float]]
    # One metadata dict per entry.
    meta_info: List[Dict]
    # One finish reason per entry.
    finished_reason: List[BaseFinishReason]
@dataclass
class FlushCacheReq:
    """Field-less message type.

    NOTE(review): by its name this presumably asks the receiver to flush a
    cache -- confirm against the code that handles it.
    """
...
...
python/sglang/srt/models/llama2.py
View file @
20a4f927
...
...
@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
sglang.srt.layers.logits_processor
import
LogitsProcessor
from
sglang.srt.layers.logits_processor
import
LogitProcessorOutput
,
LogitsProcessor
from
sglang.srt.layers.radix_attention
import
RadixAttention
from
sglang.srt.model_executor.forward_batch_info
import
InputMetadata
...
...
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
positions
:
torch
.
Tensor
,
input_metadata
:
InputMetadata
,
input_embeds
:
torch
.
Tensor
=
None
,
)
->
torch
.
Tensor
:
)
->
LogitProcessorOutput
:
hidden_states
=
self
.
model
(
input_ids
,
positions
,
input_metadata
,
input_embeds
)
return
self
.
logits_processor
(
input_ids
,
hidden_states
,
self
.
lm_head
.
weight
,
input_metadata
...
...
python/sglang/srt/openai_api/adapter.py
View file @
20a4f927
...
...
@@ -52,6 +52,8 @@ from sglang.srt.openai_api.protocol import (
CompletionResponseStreamChoice
,
CompletionStreamResponse
,
DeltaMessage
,
EmbeddingRequest
,
EmbeddingResponse
,
ErrorResponse
,
FileDeleteResponse
,
FileRequest
,
...
...
@@ -357,7 +359,6 @@ async def v1_retrieve_file_content(file_id: str):
def
v1_generate_request
(
all_requests
):
prompts
=
[]
sampling_params_list
=
[]
return_logprobs
=
[]
...
...
@@ -648,7 +649,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
def
v1_chat_generate_request
(
all_requests
,
tokenizer_manager
):
input_ids
=
[]
sampling_params_list
=
[]
image_data_list
=
[]
...
...
@@ -961,6 +961,72 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
return
response
def _is_text_prompt(prompt):
    """Return True if ``prompt`` is a string or a non-empty sequence whose first item is a string."""
    if isinstance(prompt, str):
        return True
    return len(prompt) > 0 and isinstance(prompt[0], str)


def v1_embedding_request(all_requests, tokenizer_manager):
    """Adapt OpenAI-style embedding requests into one ``EmbeddingReqInput``.

    Args:
        all_requests: Non-empty list of ``EmbeddingRequest`` objects. The
            ``input`` field of every request must have the same Python type.
        tokenizer_manager: Not used by this function; kept in the signature
            for the existing call site.

    Returns:
        A tuple ``(adapted_request, request)``. When exactly one request was
        given, ``request`` is that request object and its ``input`` is passed
        through unchanged; otherwise ``request`` is ``all_requests`` and the
        inputs are passed as a list. The inputs are forwarded as ``text``
        when they are strings / lists of strings and as ``input_ids``
        otherwise (an empty list is forwarded as ``input_ids``).

    Raises:
        AssertionError: If the requests' inputs are not all of the same type.
    """
    # ``EmbeddingRequest`` names its payload field ``input``.
    first_prompt_type = type(all_requests[0].input)

    prompts = []
    for request in all_requests:
        prompt = request.input
        assert (
            type(prompt) is first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)

    is_single_request = len(all_requests) == 1
    # A lone request forwards its input as-is; several requests are batched.
    payload = prompts[0] if is_single_request else prompts

    # All prompts share one type, so inspecting the first one is enough to
    # decide between raw text and pre-tokenized ids.
    if _is_text_prompt(prompts[0]):
        prompt_kwargs = {"text": payload}
    else:
        prompt_kwargs = {"input_ids": payload}

    adapted_request = EmbeddingReqInput(**prompt_kwargs)

    if is_single_request:
        return adapted_request, all_requests[0]
    return adapted_request, all_requests
def v1_embedding_response(request, ret, to_file=False):
    """Wrap embedding results in ``EmbeddingResponse`` objects.

    Args:
        request: Not used by this function.
        ret: Iterable of embedding results, one per prompt.
        to_file: Not used by this function.

    Returns:
        A list with one ``EmbeddingResponse`` per item of ``ret``; ``index``
        is the item's position in ``ret``.
    """
    # Use the enumerated item directly instead of re-indexing ``ret``.
    return [
        EmbeddingResponse(
            index=idx,
            embedding=ret_item,
            object="embedding",
        )
        for idx, ret_item in enumerate(ret)
    ]
async def v1_embeddings(tokenizer_manager, raw_request: Request):
    """Serve one embeddings request.

    Parses the JSON body into an ``EmbeddingRequest``, adapts it with
    ``v1_embedding_request``, awaits the first item yielded by
    ``tokenizer_manager.generate_request`` and wraps the result with
    ``v1_embedding_response``. A ``ValueError`` raised while obtaining the
    result is turned into an error response.
    """
    payload = await raw_request.json()
    adapted_request, request = v1_embedding_request(
        [EmbeddingRequest(**payload)], tokenizer_manager
    )

    try:
        result_stream = tokenizer_manager.generate_request(
            adapted_request, raw_request
        )
        # Only the first yielded item is consumed.
        ret = await result_stream.__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    # Normalise a lone result to a one-element list.
    results = ret if isinstance(ret, list) else [ret]
    return v1_embedding_response(request, results)
def
to_openai_style_logprobs
(
input_token_logprobs
=
None
,
output_token_logprobs
=
None
,
...
...
python/sglang/srt/openai_api/protocol.py
View file @
20a4f927
...
...
@@ -294,3 +294,19 @@ class ChatCompletionStreamResponse(BaseModel):
created
:
int
=
Field
(
default_factory
=
lambda
:
int
(
time
.
time
()))
model
:
str
choices
:
List
[
ChatCompletionResponseStreamChoice
]
class EmbeddingRequest(BaseModel):
    """Request body of the embeddings endpoint."""

    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings/create
    # Text or token ids to embed: a single prompt or a batch of prompts.
    input: Union[List[int], List[List[int]], str, List[str]]
    model: str
    encoding_format: str = "float"
    # Optional, so it must accept None (the default) as a valid value.
    dimensions: Optional[int] = None
    user: Optional[str] = None
class EmbeddingResponse(BaseModel):
    """One embedding entry returned by the embeddings endpoint."""

    # Position of this embedding in the result list. Declared as an integer
    # to match the OpenAI embedding object and the enumerate() index that
    # v1_embedding_response passes in.
    index: int
    # The embedding vector; None is the default, hence Optional.
    embedding: Optional[List[float]] = None
    object: str = "embedding"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment