sglang commit 20a4f927 (Unverified)

Add io struct for embedding models - step 2/3 (#987)

Authored by Ying Sheng on Aug 08, 2024; committed via GitHub on Aug 08, 2024
Parent: 0de7c2d0
Showing 4 changed files, with 146 additions and 4 deletions (+146 / -4):

python/sglang/srt/managers/io_struct.py   +60 / -0
python/sglang/srt/models/llama2.py         +2 / -2
python/sglang/srt/openai_api/adapter.py   +68 / -2
python/sglang/srt/openai_api/protocol.py  +16 / -0
python/sglang/srt/managers/io_struct.py

@@ -22,6 +22,8 @@ import uuid
```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import torch

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling_params import SamplingParams
```
@@ -166,6 +168,56 @@ class TokenizedGenerateReqInput:
```python
    stream: bool


@dataclass
class EmbeddingReqInput:
    # The input prompt. It can be a single prompt or a batch of prompts.
    text: Optional[Union[List[str], str]] = None
    # The token ids for text; one can either specify text or input_ids.
    input_ids: Optional[Union[List[List[int]], List[int]]] = None
    # The request id.
    rid: Optional[Union[List[str], str]] = None
    # Dummy sampling params for compatibility
    sampling_params: Union[List[Dict], Dict] = None

    def post_init(self):
        if (self.text is None and self.input_ids is None) or (
            self.text is not None and self.input_ids is not None
        ):
            raise ValueError("Either text or input_ids should be provided.")

        if self.text is not None:
            is_single = isinstance(self.text, str)
        else:
            is_single = isinstance(self.input_ids[0], int)
        self.is_single = is_single

        if is_single:
            if self.rid is None:
                self.rid = uuid.uuid4().hex
            self.sampling_params = {"max_new_tokens": 0}
        else:
            # support select operation
            self.batch_size = (
                len(self.text) if self.text is not None else len(self.input_ids)
            )
            if self.rid is None:
                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
            else:
                if not isinstance(self.rid, list):
                    raise ValueError("The rid should be a list.")
            self.sampling_params = [
                {"max_new_tokens": 0} for _ in range(self.batch_size)
            ]


@dataclass
class TokenizedEmbeddingReqInput:
    rid: str
    input_text: str
    input_ids: List[int]
    sampling_params: SamplingParams


@dataclass
class BatchTokenIDOut:
    rids: List[str]
```
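The normalization in post_init is easiest to see in a small usage sketch. This is illustrative only, assuming an sglang checkout that includes this commit; note that post_init is a plain method (not `__post_init__`), so callers invoke it explicitly after constructing the request.

```python
# Illustrative usage sketch (assumes an sglang checkout with this commit).
from sglang.srt.managers.io_struct import EmbeddingReqInput

# Single prompt: rid becomes one hex string, sampling_params a single dict.
single = EmbeddingReqInput(text="hello world")
single.post_init()
assert single.is_single
assert single.sampling_params == {"max_new_tokens": 0}

# Batch of prompts: rid and sampling_params become per-prompt lists.
batch = EmbeddingReqInput(text=["hello", "world"])
batch.post_init()
assert not batch.is_single and batch.batch_size == 2
assert len(batch.rid) == 2 and len(batch.sampling_params) == 2
```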
@@ -187,6 +239,14 @@ class BatchStrOut:
```python
    finished_reason: List[BaseFinishReason]


@dataclass
class BatchEmbeddingOut:
    rids: List[str]
    embeddings: List[List[float]]
    meta_info: List[Dict]
    finished_reason: List[BaseFinishReason]


@dataclass
class FlushCacheReq:
    pass
```
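For completeness, a minimal construction of the new output struct. The meta_info key below is a hypothetical example, and finished_reason entries would normally be BaseFinishReason instances produced by the scheduler.

```python
# Minimal construction sketch; the meta_info key is hypothetical.
from sglang.srt.managers.io_struct import BatchEmbeddingOut

out = BatchEmbeddingOut(
    rids=["req-0"],
    embeddings=[[0.1, 0.2, 0.3]],      # one vector per request
    meta_info=[{"prompt_tokens": 3}],  # hypothetical key
    finished_reason=[None],            # normally a BaseFinishReason
)
assert len(out.rids) == len(out.embeddings) == 1
```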
python/sglang/srt/models/llama2.py

@@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
```diff
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

-from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
```
@@ -310,7 +310,7 @@ class LlamaForCausalLM(nn.Module):
```diff
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> torch.Tensor:
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
```
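The only change here is the forward() return annotation: the method already returned the LogitsProcessor's output object, and the annotation now says so instead of claiming a bare torch.Tensor. As a rough mental model only, LogitProcessorOutput is a small container of tensors; the sketch below is an illustrative assumption, not the actual definition in sglang.srt.layers.logits_processor.

```python
# Illustrative sketch only; the real LogitProcessorOutput lives in
# sglang.srt.layers.logits_processor and its fields may differ.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class LogitProcessorOutputSketch:
    # Next-token logits for each sequence in the batch.
    next_token_logits: torch.Tensor
    # Optional logprob information, filled in only when requested.
    next_token_logprobs: Optional[torch.Tensor] = None
```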
python/sglang/srt/openai_api/adapter.py

@@ -52,6 +52,8 @@ from sglang.srt.openai_api.protocol import (
```python
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    DeltaMessage,
    EmbeddingRequest,
    EmbeddingResponse,
    ErrorResponse,
    FileDeleteResponse,
    FileRequest,
```
@@ -357,7 +359,6 @@ async def v1_retrieve_file_content(file_id: str):
```python
def v1_generate_request(all_requests):
    prompts = []
    sampling_params_list = []
    return_logprobs = []
```

@@ -648,7 +649,6 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
```python
def v1_chat_generate_request(all_requests, tokenizer_manager):
    input_ids = []
    sampling_params_list = []
    image_data_list = []
```
@@ -961,6 +961,72 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
```python
    return response


def v1_embedding_request(all_requests, tokenizer_manager):
    prompts = []
    sampling_params_list = []
    first_prompt_type = type(all_requests[0].prompt)

    for request in all_requests:
        prompt = request.prompt
        assert (
            type(prompt) == first_prompt_type
        ), "All prompts must be of the same type in file input settings"
        prompts.append(prompt)

    if len(all_requests) == 1:
        prompt = prompts[0]
        if isinstance(prompt, str) or isinstance(prompt[0], str):
            prompt_kwargs = {"text": prompt}
        else:
            prompt_kwargs = {"input_ids": prompt}
    else:
        if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
            prompt_kwargs = {"text": prompts}
        else:
            prompt_kwargs = {"input_ids": prompts}

    adapted_request = EmbeddingReqInput(
        **prompt_kwargs,
    )

    if len(all_requests) == 1:
        return adapted_request, all_requests[0]
    return adapted_request, all_requests


def v1_embedding_response(request, ret, to_file=False):
    response = []
    for idx, ret_item in enumerate(ret):
        response.append(
            EmbeddingResponse(
                index=idx,
                embedding=ret[idx],
                object="embedding",
            )
        )
    return response


async def v1_embeddings(tokenizer_manager, raw_request: Request):
    request_json = await raw_request.json()
    all_requests = [EmbeddingRequest(**request_json)]
    adapted_request, request = v1_embedding_request(all_requests, tokenizer_manager)

    try:
        ret = await tokenizer_manager.generate_request(
            adapted_request, raw_request
        ).__anext__()
    except ValueError as e:
        return create_error_response(str(e))

    if not isinstance(ret, list):
        ret = [ret]

    response = v1_embedding_response(request, ret)
    return response


def to_openai_style_logprobs(
    input_token_logprobs=None,
    output_token_logprobs=None,
```
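Once an HTTP route is wired to v1_embeddings (the commit title marks this as step 2/3, so the route may land in a follow-up), the handler can be exercised like any OpenAI-compatible embeddings API. A usage sketch, assuming a server listening on localhost:30000 and a placeholder model name:

```python
# Usage sketch; the port and model name are placeholders.
import requests

resp = requests.post(
    "http://localhost:30000/v1/embeddings",
    json={
        "model": "my-embedding-model",   # placeholder
        "input": "The quick brown fox",  # or a list of strings / token ids
    },
)
print(resp.json())
```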
python/sglang/srt/openai_api/protocol.py

@@ -294,3 +294,19 @@ class ChatCompletionStreamResponse(BaseModel):
```python
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]


class EmbeddingRequest(BaseModel):
    # Ordered by official OpenAI API documentation
    # https://platform.openai.com/docs/api-reference/embeddings/create
    input: Union[List[int], List[List[int]], str, List[str]]
    model: str
    encoding_format: str = "float"
    dimensions: int = None
    user: Optional[str] = None


class EmbeddingResponse(BaseModel):
    index: int
    embedding: List[float] = None
    object: str = "embedding"
```