change / sglang · Commits · 76619261

Unverified commit 76619261, authored Nov 18, 2024 by Yineng Zhang, committed by GitHub on Nov 18, 2024.

feat: update torch 2.5.1 (#2069)

Parent: 2a3992b6

Showing 10 changed files with 127 additions and 33 deletions (+127, -33).
Changed files:

  Makefile                                               +12   -0
  python/pyproject.toml                                    +1   -1
  python/sglang/srt/layers/activation.py                   +3   -0
  python/sglang/srt/layers/custom_op_util.py              +26   -0
  python/sglang/srt/layers/layernorm.py                    +4   -0
  python/sglang/srt/model_executor/cuda_graph_runner.py    +2   -0
  python/sglang/srt/model_executor/model_runner.py        +50  -32
  python/sglang/srt/utils.py                              +22   -0
  python/sglang/test/test_utils.py                         +6   -0
  test/srt/test_bench_serving.py                           +1   -0
Makefile (new file, mode 100644)

.PHONY: check-deps install-deps format

check-deps:
	@command -v isort > /dev/null 2>&1 || (echo "Installing isort..." && pip install isort)
	@command -v black > /dev/null 2>&1 || (echo "Installing black..." && pip install black)

install-deps:
	pip install isort black

format: check-deps
	@echo "Formatting modified Python files..."
	git diff --name-only --diff-filter=M | grep '\.py$$' | xargs -I {} sh -c 'isort {} && black {}'
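
With this Makefile, running "make format" reformats only the Python files that currently have local (unstaged) modifications, selected via git diff --name-only --diff-filter=M, and the check-deps prerequisite installs isort and black on demand if they are missing; "make install-deps" installs both tools explicitly.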
python/pyproject.toml

@@ -20,7 +20,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "orjson", "packaging", "pillow", "prometheus-client>=0.20.0", "psutil", "pydantic", "python-multipart", "torchao", "uvicorn", "uvloop", "pyzmq>=25.1.2", "outlines>=0.0.44,<0.1.0", "modelscope"]
-srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
+srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1"]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
python/sglang/srt/layers/activation.py

@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2

@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out


+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
python/sglang/srt/layers/custom_op_util.py (new file, mode 100644)

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

from vllm.model_executor.custom_op import CustomOp


def register_custom_op(op_name):
    def decorator(cls):
        if hasattr(CustomOp, "register"):
            return CustomOp.register(op_name)(cls)
        else:
            return cls

    return decorator
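
As a rough illustration of how the new helper is used (the op name and class below are hypothetical, and vllm must be installed for CustomOp to import):

# Hypothetical sketch, not part of the commit: applying register_custom_op
# to a CustomOp subclass. On vllm versions that expose CustomOp.register
# (such as the vllm>=0.6.3.post1 allowed by this change), the class is
# registered under the given name; on older versions the decorator simply
# returns the class unchanged.
from vllm.model_executor.custom_op import CustomOp

from sglang.srt.layers.custom_op_util import register_custom_op


@register_custom_op("sglang_example_op")  # hypothetical op name
class ExampleOp(CustomOp):
    def forward_native(self, x):
        return x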
python/sglang/srt/layers/layernorm.py

@@ -33,9 +33,12 @@ if is_flashinfer_available():
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
+
 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_rmsnorm")
 class RMSNorm(CustomOp):
     def __init__(
         self,

@@ -78,6 +81,7 @@ class RMSNorm(CustomOp):
         return x, residual


+@register_custom_op("sglang_gemma_rmsnorm")
 class GemmaRMSNorm(CustomOp):
     def __init__(
         self,
python/sglang/srt/model_executor/cuda_graph_runner.py

@@ -90,6 +90,8 @@ def set_torch_compile_config():
     # FIXME: tmp workaround
     torch._dynamo.config.accumulated_cache_size_limit = 1024
+    if hasattr(torch._dynamo.config, "cache_size_limit"):
+        torch._dynamo.config.cache_size_limit = 1024


 @maybe_torch_compile(dynamic=True)
python/sglang/srt/model_executor/model_runner.py

@@ -18,9 +18,9 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import inspect
 import json
 import logging
 import os
 import pkgutil
 from functools import lru_cache
 from typing import Optional, Type
@@ -60,6 +60,7 @@ from sglang.srt.utils import (
     crash_on_warnings,
     enable_show_time_cost,
     get_available_gpu_memory,
+    monkey_patch_vllm_model_config,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -226,6 +227,47 @@ class ModelRunner:
         return min_per_gpu_memory

+    def setup_model(self):
+        try:
+            from vllm.config import VllmConfig
+
+            vllm_config = VllmConfig()
+            vllm_config.model_config = self.vllm_model_config
+            vllm_config.load_config = self.load_config
+            vllm_config.device_config = DeviceConfig(self.device)
+            vllm_config.quant_config = VllmConfig._get_quantization_config(
+                vllm_config.model_config, vllm_config.load_config
+            )
+            return get_model(vllm_config=vllm_config)
+        except ImportError:
+            return get_model(
+                model_config=self.vllm_model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+                parallel_config=None,
+                scheduler_config=None,
+                lora_config=None,
+                cache_config=None,
+            )
+
+    def get_model_config_params(self):
+        sig = inspect.signature(VllmModelConfig.__init__)
+        params = {
+            "model": self.server_args.model_path,
+            "quantization": self.server_args.quantization,
+            "tokenizer": None,
+            "tokenizer_mode": None,
+            "trust_remote_code": self.server_args.trust_remote_code,
+            "dtype": self.server_args.dtype,
+            "seed": self.server_args.random_seed,
+            "skip_tokenizer_init": True,
+        }
+        if "task" in sig.parameters:
+            params["task"] = ""
+        return params
+
     def load_model(self):
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -247,31 +289,15 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-        self.vllm_model_config = VllmModelConfig(
-            model=self.server_args.model_path,
-            quantization=self.server_args.quantization,
-            tokenizer=None,
-            tokenizer_mode=None,
-            trust_remote_code=self.server_args.trust_remote_code,
-            dtype=self.server_args.dtype,
-            seed=self.server_args.random_seed,
-            skip_tokenizer_init=True,
-        )
+        monkey_patch_vllm_model_config()
+        self.vllm_model_config = VllmModelConfig(**self.get_model_config_params())
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )

-        # Load the model
-        self.model = get_model(
-            model_config=self.vllm_model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-            parallel_config=None,
-            scheduler_config=None,
-            lora_config=None,
-            cache_config=None,
-        )
+        self.model = self.setup_model()
+
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -303,17 +329,9 @@ class ModelRunner:
         target_device = torch.device(self.device)

         try:
-            # TODO: Use a better method to check this
-            vllm_model_config = VllmModelConfig(
-                model=model_path,
-                quantization=self.server_args.quantization,
-                tokenizer=None,
-                tokenizer_mode=None,
-                trust_remote_code=self.server_args.trust_remote_code,
-                dtype=self.server_args.dtype,
-                seed=self.server_args.random_seed,
-                skip_tokenizer_init=True,
-            )
+            model_config_params = self.get_model_config_params()
+            model_config_params["model"] = model_path
+            vllm_model_config = VllmModelConfig(**model_config_params)
         except Exception as e:
             message = f"Failed to load model config: {e}."
             return False, message
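
The new get_model_config_params helper probes VllmModelConfig.__init__ with inspect.signature so the same keyword set works whether or not the installed vllm accepts a task argument. A minimal, self-contained sketch of that pattern (both config classes below are hypothetical stand-ins, not vllm code):

import inspect


class LegacyConfig:  # hypothetical: older constructor without a `task` parameter
    def __init__(self, model, dtype="auto"):
        self.model, self.dtype = model, dtype


class NewConfig:  # hypothetical: newer constructor that also takes `task`
    def __init__(self, model, dtype="auto", task=""):
        self.model, self.dtype, self.task = model, dtype, task


def build_config(config_cls, **params):
    # Only pass `task` when the target constructor actually declares it.
    if "task" in inspect.signature(config_cls.__init__).parameters:
        params.setdefault("task", "")
    return config_cls(**params)


build_config(LegacyConfig, model="llama")  # no `task` kwarg is passed
build_config(NewConfig, model="llama")     # receives task="" automatically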
python/sglang/srt/utils.py

@@ -332,6 +332,7 @@ def suppress_other_loggers():
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
     logging.getLogger("vllm.utils").setLevel(logging.ERROR)
+    logging.getLogger("vllm.model_executor.model_loader.loader").setLevel(logging.ERROR)
     warnings.filterwarnings(
         "ignore", category=UserWarning, message="The given NumPy array is not writable"

@@ -396,6 +397,27 @@ def kill_child_process(pid=None, include_self=False, skip_pid=None):
         pass


+def monkey_patch_vllm_model_config():
+    from vllm.config import ModelConfig
+
+    if not hasattr(ModelConfig, "_resolve_task"):
+        return
+
+    def _resolve_task(
+        self,
+        task_option,
+        hf_config,
+    ):
+        supported_tasks = {
+            "generate": True,
+            "embedding": False,
+        }
+        selected_task = "generate"
+        return supported_tasks, selected_task
+
+    setattr(ModelConfig, "_resolve_task", _resolve_task)
+
+
 def monkey_patch_vllm_p2p_access_check(gpu_id: int):
     """
     Monkey patch the slow p2p access check in vllm.
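
monkey_patch_vllm_model_config replaces ModelConfig._resolve_task (when that attribute exists) with a stub that always resolves to the "generate" task. A tiny standalone sketch of the same setattr-based patching technique (the Greeter class is hypothetical, not vllm code):

class Greeter:
    def greet(self):
        return "hello"


def _always_generate(self):
    # Mirrors the patch above: return one fixed answer regardless of input.
    return "generate"


# Swap the method on the class at runtime, as the monkey patch does for
# ModelConfig._resolve_task.
setattr(Greeter, "greet", _always_generate)
assert Greeter().greet() == "generate"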
python/sglang/test/test_utils.py

@@ -2,6 +2,7 @@
 import argparse
 import asyncio
+import copy
 import os
 import random
 import subprocess

@@ -529,6 +530,7 @@ def run_bench_serving(
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
+    need_warmup=False,
 ):
     # Launch the server
     base_url = DEFAULT_URL_FOR_TEST

@@ -565,6 +567,10 @@ def run_bench_serving(
     )

     try:
+        if need_warmup:
+            warmup_args = copy.deepcopy(args)
+            warmup_args.num_prompts = 16
+            run_benchmark(warmup_args)
         res = run_benchmark(args)
     finally:
         kill_child_process(process.pid, include_self=True)
test/srt/test_bench_serving.py

@@ -32,6 +32,7 @@ class TestBenchServing(unittest.TestCase):
             random_input_len=None,
             random_output_len=None,
             disable_stream=True,
+            need_warmup=True,
         )

         if is_in_ci():