sglang / Commits / 1df6eabd

Commit 1df6eabd (Unverified)
authored Feb 21, 2025 by Andrew Smith, committed by GitHub on Feb 21, 2025

feat: Add SageMaker support (#3740)
parent 0c227ee3
Showing 4 changed files with 299 additions and 0 deletions
docker/Dockerfile.sagemaker                    +78  -0
docker/serve                                   +31  -0
python/sglang/srt/entrypoints/http_server.py   +12  -0
test/srt/test_sagemaker_server.py             +178  -0
docker/Dockerfile.sagemaker  0 → 100644
ARG CUDA_VERSION=12.5.1
FROM nvcr.io/nvidia/tritonserver:24.04-py3-min
ARG BUILD_TYPE=all
ENV DEBIAN_FRONTEND=noninteractive
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
&& apt update -y \
&& apt install software-properties-common -y \
&& add-apt-repository ppa:deadsnakes/ppa -y && apt update \
&& apt install python3.10 python3.10-dev -y \
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1 \
&& update-alternatives --set python3 /usr/bin/python3.10 && apt install python3.10-distutils -y \
&& apt install curl git sudo libibverbs-dev -y \
&& apt install -y rdma-core infiniband-diags openssh-server perftest ibverbs-providers libibumad3 libibverbs1 libnl-3-200 libnl-route-3-200 librdmacm1 \
&& curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3 get-pip.py \
&& python3 --version \
&& python3 -m pip --version \
&& rm -rf /var/lib/apt/lists/* \
&& apt clean
# For openbmb/MiniCPM models
RUN pip3 install datamodel_code_generator
WORKDIR /sgl-workspace
ARG CUDA_VERSION
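# Pick the PyTorch wheel index matching the requested CUDA version.
# Note that CUDA 12.5.1 falls back to the cu124 wheels below.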
RUN python3 -m pip install --upgrade pip setuptools wheel html5lib six \
&& git clone --depth=1 https://github.com/sgl-project/sglang.git \
&& if [ "$CUDA_VERSION" = "12.1.1" ]; then \
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu121; \
elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \
elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu124; \
elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118; \
python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
else \
echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
fi \
&& cd sglang \
&& if [ "$BUILD_TYPE" = "srt" ]; then \
if [ "$CUDA_VERSION" = "12.1.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
python3 -m pip --no-cache-dir install -e "python[srt]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \
python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
else \
echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
fi; \
else \
if [ "$CUDA_VERSION" = "12.1.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu121/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "12.4.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "12.5.1" ]; then \
python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python; \
elif [ "$CUDA_VERSION" = "11.8.0" ]; then \
python3 -m pip --no-cache-dir install -e "python[all]" --find-links https://flashinfer.ai/whl/cu118/torch2.5/flashinfer-python; \
python3 -m pip install sgl-kernel -i https://docs.sglang.ai/whl/cu118; \
else \
echo "Unsupported CUDA version: $CUDA_VERSION" && exit 1; \
fi; \
fi
ENV DEBIAN_FRONTEND=interactive
COPY serve /usr/bin/serve
RUN chmod 777 /usr/bin/serve
ENTRYPOINT [ "/usr/bin/serve" ]
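For reference, the image can be built straight from this Dockerfile. A minimal sketch, run from the repository root with docker/ as the build context (so that COPY serve resolves); the image tag is illustrative and the build args echo the defaults declared above:

docker build -f docker/Dockerfile.sagemaker \
    --build-arg CUDA_VERSION=12.5.1 \
    --build-arg BUILD_TYPE=all \
    -t sglang-sagemaker:latest docker/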
docker/serve  0 → 100755
#!/bin/bash

echo "Starting server"

SERVER_ARGS="--host 0.0.0.0 --port 8080"

if [ -n "$TENSOR_PARALLEL_DEGREE" ]; then
    SERVER_ARGS="${SERVER_ARGS} --tp-size ${TENSOR_PARALLEL_DEGREE}"
fi

if [ -n "$DATA_PARALLEL_DEGREE" ]; then
    SERVER_ARGS="${SERVER_ARGS} --dp-size ${DATA_PARALLEL_DEGREE}"
fi

if [ -n "$EXPERT_PARALLEL_DEGREE" ]; then
    SERVER_ARGS="${SERVER_ARGS} --ep-size ${EXPERT_PARALLEL_DEGREE}"
fi

if [ -n "$MEM_FRACTION_STATIC" ]; then
    SERVER_ARGS="${SERVER_ARGS} --mem-fraction-static ${MEM_FRACTION_STATIC}"
fi

if [ -n "$QUANTIZATION" ]; then
    SERVER_ARGS="${SERVER_ARGS} --quantization ${QUANTIZATION}"
fi

if [ -n "$CHUNKED_PREFILL_SIZE" ]; then
    SERVER_ARGS="${SERVER_ARGS} --chunked-prefill-size ${CHUNKED_PREFILL_SIZE}"
fi

python3 -m sglang.launch_server --model-path /opt/ml/model $SERVER_ARGS
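A minimal local smoke test of this entrypoint, assuming the image tag from the build example above and a model directory mounted where SageMaker places model artifacts (/opt/ml/model); on a real endpoint these environment variables would come from the endpoint's Environment configuration rather than -e flags:

docker run --gpus all -p 8080:8080 \
    -v /path/to/model:/opt/ml/model \
    -e TENSOR_PARALLEL_DEGREE=2 \
    -e MEM_FRACTION_STATIC=0.8 \
    sglang-sagemaker:latest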
python/sglang/srt/entrypoints/http_server.py
...
@@ -463,6 +463,18 @@ async def retrieve_file_content(file_id: str):
    return await v1_retrieve_file_content(file_id)


## SageMaker API
@app.get("/ping")
async def sagemaker_health() -> Response:
    """Check the health of the http server."""
    return Response(status_code=200)


@app.post("/invocations")
async def sagemaker_chat_completions(raw_request: Request):
    return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)


def _create_error_response(e):
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
...
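With the server listening on port 8080, the two new routes can be exercised as sketched below. /ping mirrors SageMaker's container health check, and /invocations forwards to the OpenAI-compatible chat completions handler, so it accepts a chat payload; the model value shown here is illustrative:

curl http://localhost:8080/ping

curl -X POST http://localhost:8080/invocations \
    -H "Content-Type: application/json" \
    -d '{"model": "default", "messages": [{"role": "user", "content": "What is the capital of France?"}]}'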
test/srt/test_sagemaker_server.py  0 → 100644
"""
python3 -m unittest test_sagemaker_server.TestSageMakerServer.test_chat_completion
"""
import
json
import
unittest
import
requests
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.utils
import
kill_process_tree
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
class
TestSageMakerServer
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
cls
.
model
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls
.
base_url
=
DEFAULT_URL_FOR_TEST
cls
.
api_key
=
"sk-123456"
cls
.
process
=
popen_launch_server
(
cls
.
model
,
cls
.
base_url
,
timeout
=
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
api_key
=
cls
.
api_key
,
)
cls
.
tokenizer
=
get_tokenizer
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
)
@
classmethod
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
run_chat_completion
(
self
,
logprobs
,
parallel_sample_num
):
data
=
{
"model"
:
self
.
model
,
"messages"
:
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful AI assistant"
},
{
"role"
:
"user"
,
"content"
:
"What is the capital of France? Answer in a few words."
,
},
],
"temperature"
:
0
,
"logprobs"
:
logprobs
is
not
None
and
logprobs
>
0
,
"top_logprobs"
:
logprobs
,
"n"
:
parallel_sample_num
,
}
headers
=
{
"Authorization"
:
f
"Bearer
{
self
.
api_key
}
"
}
response
=
requests
.
post
(
f
"
{
self
.
base_url
}
/invocations"
,
json
=
data
,
headers
=
headers
).
json
()
if
logprobs
:
assert
isinstance
(
response
[
"choices"
][
0
][
"logprobs"
][
"content"
][
0
][
"top_logprobs"
][
0
][
"token"
],
str
,
)
ret_num_top_logprobs
=
len
(
response
[
"choices"
][
0
][
"logprobs"
][
"content"
][
0
][
"top_logprobs"
]
)
assert
(
ret_num_top_logprobs
==
logprobs
),
f
"
{
ret_num_top_logprobs
}
vs
{
logprobs
}
"
assert
len
(
response
[
"choices"
])
==
parallel_sample_num
assert
response
[
"choices"
][
0
][
"message"
][
"role"
]
==
"assistant"
assert
isinstance
(
response
[
"choices"
][
0
][
"message"
][
"content"
],
str
)
assert
response
[
"id"
]
assert
response
[
"created"
]
assert
response
[
"usage"
][
"prompt_tokens"
]
>
0
assert
response
[
"usage"
][
"completion_tokens"
]
>
0
assert
response
[
"usage"
][
"total_tokens"
]
>
0
def
run_chat_completion_stream
(
self
,
logprobs
,
parallel_sample_num
=
1
):
data
=
{
"model"
:
self
.
model
,
"messages"
:
[
{
"role"
:
"system"
,
"content"
:
"You are a helpful AI assistant"
},
{
"role"
:
"user"
,
"content"
:
"What is the capital of France? Answer in a few words."
,
},
],
"temperature"
:
0
,
"logprobs"
:
logprobs
is
not
None
and
logprobs
>
0
,
"top_logprobs"
:
logprobs
,
"stream"
:
True
,
"stream_options"
:
{
"include_usage"
:
True
},
"n"
:
parallel_sample_num
,
}
headers
=
{
"Authorization"
:
f
"Bearer
{
self
.
api_key
}
"
}
response
=
requests
.
post
(
f
"
{
self
.
base_url
}
/invocations"
,
json
=
data
,
stream
=
True
,
headers
=
headers
)
is_firsts
=
{}
for
line
in
response
.
iter_lines
():
line
=
line
.
decode
(
"utf-8"
).
replace
(
"data: "
,
""
)
if
len
(
line
)
<
1
or
line
==
"[DONE]"
:
continue
print
(
f
"value:
{
line
}
"
)
line
=
json
.
loads
(
line
)
usage
=
line
.
get
(
"usage"
)
if
usage
is
not
None
:
assert
usage
[
"prompt_tokens"
]
>
0
assert
usage
[
"completion_tokens"
]
>
0
assert
usage
[
"total_tokens"
]
>
0
continue
index
=
line
.
get
(
"choices"
)[
0
].
get
(
"index"
)
data
=
line
.
get
(
"choices"
)[
0
].
get
(
"delta"
)
if
is_firsts
.
get
(
index
,
True
):
assert
data
[
"role"
]
==
"assistant"
is_firsts
[
index
]
=
False
continue
if
logprobs
:
assert
line
.
get
(
"choices"
)[
0
].
get
(
"logprobs"
)
assert
isinstance
(
line
.
get
(
"choices"
)[
0
]
.
get
(
"logprobs"
)
.
get
(
"content"
)[
0
]
.
get
(
"top_logprobs"
)[
0
]
.
get
(
"token"
),
str
,
)
assert
isinstance
(
line
.
get
(
"choices"
)[
0
]
.
get
(
"logprobs"
)
.
get
(
"content"
)[
0
]
.
get
(
"top_logprobs"
),
list
,
)
ret_num_top_logprobs
=
len
(
line
.
get
(
"choices"
)[
0
]
.
get
(
"logprobs"
)
.
get
(
"content"
)[
0
]
.
get
(
"top_logprobs"
)
)
assert
(
ret_num_top_logprobs
==
logprobs
),
f
"
{
ret_num_top_logprobs
}
vs
{
logprobs
}
"
assert
isinstance
(
data
[
"content"
],
str
)
assert
line
[
"id"
]
assert
line
[
"created"
]
for
index
in
[
i
for
i
in
range
(
parallel_sample_num
)]:
assert
not
is_firsts
.
get
(
index
,
True
),
f
"index
{
index
}
is not found in the response"
def
test_chat_completion
(
self
):
for
logprobs
in
[
None
,
5
]:
for
parallel_sample_num
in
[
1
,
2
]:
self
.
run_chat_completion
(
logprobs
,
parallel_sample_num
)
def
test_chat_completion_stream
(
self
):
for
logprobs
in
[
None
,
5
]:
for
parallel_sample_num
in
[
1
,
2
]:
self
.
run_chat_completion_stream
(
logprobs
,
parallel_sample_num
)
if
__name__
==
"__main__"
:
unittest
.
main
()
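Beyond this local unittest, a deployed endpoint could be invoked through the SageMaker runtime. A sketch using the AWS CLI, where the endpoint name is hypothetical and the payload is the same chat-completion shape the tests send to /invocations (--cli-binary-format raw-in-base64-out lets AWS CLI v2 accept a raw JSON body):

aws sagemaker-runtime invoke-endpoint \
    --endpoint-name my-sglang-endpoint \
    --content-type application/json \
    --cli-binary-format raw-in-base64-out \
    --body '{"messages": [{"role": "user", "content": "Hello"}]}' \
    response.json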