change / sglang · Commits · f6af3a65

Unverified commit f6af3a65, authored Aug 24, 2024 by Lianmin Zheng, committed via GitHub on Aug 24, 2024
Parent: c9064e6f

Cleanup readme, llava examples, usage examples and nccl init (#1194)

Showing 20 changed files with 110 additions and 199 deletions (+110 −199)
Changed files:

  examples/runtime/llava_onevision/http_qwen_llava_test.py        +1   −2
  examples/runtime/openai_batch_chat.py                            +0   −0
  examples/runtime/openai_batch_complete.py                        +0   −0
  examples/usage/llava/srt_llava_next_test.py                      +0   −90
  examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png    +0   −0
  python/sglang/bench_latency.py                                   +5   −1
  python/sglang/lang/chat_template.py                              +2   −2
  python/sglang/launch_server_llavavid.py                          +0   −29
  python/sglang/srt/layers/decode_attention.py                     +1   −1
  python/sglang/srt/layers/fused_moe/layer.py                      +2   −2
  python/sglang/srt/layers/logits_processor.py                     +2   −2
  python/sglang/srt/managers/schedule_batch.py                     +1   −1
  python/sglang/srt/managers/tokenizer_manager.py                  +3   −0
  python/sglang/srt/model_executor/model_runner.py                 +55  −31
  python/sglang/srt/models/gemma2.py                               +8   −4
  python/sglang/srt/models/grok.py                                 +13  −4
  python/sglang/srt/server_args.py                                 +9   −10
  python/sglang/srt/utils.py                                       +1   −5
  python/sglang/test/runners.py                                    +7   −15
  scripts/deprecated/convert_yi_vl.py                              +0   −0
examples/usage/llava/http_qwen_llava_test.py → examples/runtime/llava_onevision/http_qwen_llava_test.py

@@ -4,7 +4,7 @@ Usage:
 # Installing latest sglang.
 # Endpoint Service CLI:
-# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
+python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8

 python3 http_qwen_llava_test.py

@@ -16,7 +16,6 @@ import argparse
 import asyncio
 import copy
 import json
-import time

 import aiohttp
 import requests
examples/usage/openai_batch_chat.py → examples/runtime/openai_batch_chat.py

File moved.
examples/usage/openai_batch_complete.py → examples/runtime/openai_batch_complete.py

File moved.
examples/usage/llava/srt_llava_next_test.py (deleted, 100644 → 0)

"""
Usage: python3 srt_example_llava.py
"""

from PIL import ImageFile

import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image

ImageFile.LOAD_TRUNCATED_IMAGES = True  # Allow loading of truncated images


@sgl.function
def image_qa(s, image, question):
    s += sgl.user(sgl.image(image) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single():
    image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    pil_image, _ = load_image(image_url)
    state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
    print(state["answer"], "\n")


def stream():
    image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    pil_image, _ = load_image(image_url)
    state = image_qa.run(
        image=pil_image,
        question="Please generate short caption for this image.",
        max_new_tokens=512,
        temperature=0,
        stream=True,
    )
    for out in state.text_iter("answer"):
        print(out, end="", flush=True)
    print()


def batch():
    image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
    pil_image, _ = load_image(image_url)
    states = image_qa.run_batch(
        [
            {"image": pil_image, "question": "What is this?"},
            {"image": pil_image, "question": "What is this?"},
        ],
        max_new_tokens=512,
    )
    for s in states:
        print(s["answer"], "\n")


if __name__ == "__main__":
    import multiprocessing as mp

    mp.set_start_method("spawn", force=True)

    runtime = sgl.Runtime(
        model_path="lmms-lab/llama3-llava-next-8b",
        tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
    )
    runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")

    # runtime = sgl.Runtime(
    #     model_path="lmms-lab/llava-next-72b",
    #     tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
    # )
    # runtime.endpoint.chat_template = get_chat_template("chatml-llava")

    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Or you can use API models
    # sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
    # sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))

    # Run a single request
    print("\n========== single ==========\n")
    single()

    # Stream output
    print("\n========== stream ==========\n")
    stream()

    # Run a batch of requests
    print("\n========== batch ==========\n")
    batch()

    runtime.shutdown()
examples/usage/rag_using_parea/max-tokens-fixed-rag-trace.png (deleted, 100644 → 0, 132 KB)
python/sglang/bench_latency.py

@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(path=server_args.model_path)
+    model_config = ModelConfig(
+        server_args.model_path,
+        server_args.trust_remote_code,
+        context_length=server_args.context_length,
+    )
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
python/sglang/lang/chat_template.py

-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple


 class ChatTemplateStyle(Enum):
python/sglang/launch_server_llavavid.py (deleted, 100644 → 0)

"""Launch the inference server for Llava-video model."""

import argparse

from sglang.srt.server import ServerArgs, launch_server

if __name__ == "__main__":
    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = 2
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = 16
    model_overide_args["model_type"] = "llavavid"

    if model_overide_args["num_frames"] == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
        model_overide_args["model_max_length"] = 4096 * 2

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)
    args = parser.parse_args()

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    server_args = ServerArgs.from_cli_args(args)

    launch_server(server_args, model_overide_args, None)
python/sglang/srt/layers/decode_attention.py

@@ -26,7 +26,7 @@ import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
python/sglang/srt/layers/fused_moe/layer.py

@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: int,
         expert_id: int,
-        pre_sharded: bool,
+        use_presharded_weights: bool = False,
     ):
         param_data = param.data

@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
-            if pre_sharded:
+            if use_presharded_weights:
                 shard = slice(None)
             else:
                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
python/sglang/srt/layers/logits_processor.py

@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):
         if hasattr(self.config, "final_logit_softcapping"):
             last_logits.div_(self.config.final_logit_softcapping)
-            last_logits = torch.tanh(last_logits)
+            torch.tanh(last_logits, out=last_logits)
             last_logits.mul_(self.config.final_logit_softcapping)

         # Return only last_logits if logprob is not requested

@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):
             if hasattr(self.config, "final_logit_softcapping"):
                 all_logits.div_(self.config.final_logit_softcapping)
-                all_logits = torch.tanh(all_logits)
+                torch.tanh(all_logits, out=all_logits)
                 all_logits.mul_(self.config.final_logit_softcapping)

             all_logprobs = all_logits
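Note: both hunks replace an out-of-place torch.tanh with the out= form, so the softcapping sequence (divide, tanh, multiply) runs without allocating a fresh logits tensor. A small self-contained sketch of the same pattern; the cap value 30.0 is just an example, not a value from the commit:

import torch

def softcap_(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # In-place softcapping: cap * tanh(logits / cap), written without new allocations.
    logits.div_(cap)
    torch.tanh(logits, out=logits)  # same idea as the change above
    logits.mul_(cap)
    return logits

x = torch.randn(4, 8) * 100
softcap_(x, 30.0)
print(bool((x.abs() <= 30.0).all()))  # True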
python/sglang/srt/managers/schedule_batch.py

@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
python/sglang/srt/managers/tokenizer_manager.py

@@ -606,6 +606,9 @@ class TokenizerManager:
         return background_tasks

     def create_handle_loop(self):
+        if not self.to_create_loop:
+            return
+
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
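Note: the added guard makes create_handle_loop a no-op after its first call. A minimal sketch of the same idempotent-start pattern; the class and the body of handle_loop here are hypothetical stand-ins, not the real TokenizerManager:

import asyncio

class HandleLoopOwner:
    def __init__(self):
        self.to_create_loop = True

    def create_handle_loop(self):
        if not self.to_create_loop:
            return  # already started; repeated calls no longer spawn extra tasks
        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        loop.create_task(self.handle_loop())

    async def handle_loop(self):
        await asyncio.sleep(0)  # placeholder for the real receive loop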
python/sglang/srt/model_executor/model_runner.py

@@ -20,7 +20,6 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
 from typing import Optional, Type

@@ -91,23 +90,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )

+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")

-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
-        set_custom_all_reduce(not server_args.disable_custom_all_reduce)
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,

@@ -116,32 +127,28 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()

+        # Currently, there is a bug with mulit-node tensor parallelsim + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
-            if total_local_gpu_memory < total_gpu_memory * 0.9:
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

-        # Load the model and create memory pool
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory

     def load_model(self):
         logger.info(

@@ -150,7 +157,7 @@ class ModelRunner:
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
-                "Compute capability below sm80 use float16 due to lack of bfloat16 support."
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"

@@ -168,8 +175,9 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )

-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
+            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+            # Drop this after Sept, 2024.
             self.model_config.hf_config.num_key_value_heads = 8
             self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()

@@ -191,8 +199,8 @@ class ModelRunner:
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(

@@ -206,7 +214,8 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

-    def update_weights(self, model_path, load_format):
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
         from vllm.model_executor.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,

@@ -222,6 +231,7 @@ class ModelRunner:
         target_device = torch.device(self.device_config.device)

         try:
+            # TODO: Use a better method to check this
             vllm_model_config = VllmModelConfig(
                 model=model_path,
                 quantization=self.server_args.quantization,

@@ -291,7 +301,7 @@ class ModelRunner:
         logger.info(f"[gpu={self.gpu_id}] Update weights end.")
         return True, "Succeeded to update model weights"

-    def profile_max_num_token(self, total_gpu_memory):
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )

@@ -319,7 +329,10 @@ class ModelRunner:
         return max_num_token

     def init_memory_pool(
-        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:

@@ -388,6 +401,7 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None

@@ -448,6 +462,11 @@ class ModelRunner:
         )

     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:

@@ -457,7 +476,12 @@ class ModelRunner:
         logger.info(
             f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
         )
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
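Note: one behavioral detail from the last hunk is that the set of batch sizes captured for CUDA graphs now depends on disable_cuda_graph_padding, which the same commit forces on for multi-node tensor parallelism. A quick sketch that simply materializes the two lists written in the diff:

# Capture lists exactly as written in the hunk above.
padded_capture_sizes = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, 8, 16, ..., 160
dense_capture_sizes = list(range(1, 32)) + [64, 128]              # every size 1..31, plus 64 and 128

print(max(padded_capture_sizes))  # 160
print(max(dense_capture_sizes))   # 128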
python/sglang/srt/models/gemma2.py

@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def get_window_size(config):
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1

@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else None,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )

@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

-    def get_window_size(self):
-        return get_window_size(self.config)
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
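Note: the renamed helper keeps the same off-by-one conversion: HF treats sliding_window as inclusive of the last token, while SGLang's attention expects an exclusive size, hence the -1. A tiny sketch with a stand-in config object; the 4096 value is illustrative, not from the commit:

from types import SimpleNamespace

def get_attention_sliding_window_size(config):
    # HF: window inclusive of the last token; SGLang: exclusive, hence the -1.
    return config.sliding_window - 1

config = SimpleNamespace(sliding_window=4096)
print(get_attention_sliding_window_size(config))  # 4095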
python/sglang/srt/models/grok.py

@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.logits_processor = LogitsProcessor(config)
+        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
+        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        self.use_presharded_weights = True

         warnings.filterwarnings("ignore", category=FutureWarning)

     def forward(

@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
+
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(

@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
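Note: the load_weights change forwards the new flag only when it is set, so weight loaders that were not updated to accept use_presharded_weights keep working unchanged. A minimal sketch of that kwargs-forwarding pattern; the loader function and its other parameters here are hypothetical:

def example_weight_loader(weight_name, shard_id=None, expert_id=None, use_presharded_weights=False):
    # Hypothetical loader: only reports what it would do.
    mode = "presharded" if use_presharded_weights else "full"
    print(weight_name, shard_id, expert_id, mode)

use_presharded_weights = True
extra_kwargs = (
    {"use_presharded_weights": use_presharded_weights} if use_presharded_weights else {}
)
example_weight_loader("experts.w1", shard_id=0, expert_id=3, **extra_kwargs)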
python/sglang/srt/server_args.py

@@ -81,13 +81,12 @@ class ServerArgs:
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
     enable_mla: bool = False
-    attention_reduce_in_fp32: bool = False
-    efficient_weight_load: bool = False
-    disable_custom_all_reduce: bool = False
+    triton_attention_reduce_in_fp32: bool = False

     # Distributed args
     nccl_init_addr: Optional[str] = None

@@ -404,6 +403,12 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",

@@ -425,7 +430,7 @@ class ServerArgs:
             help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",

@@ -435,12 +440,6 @@ class ServerArgs:
             action="store_true",
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
-        parser.add_argument(
-            "--disable-custom-all-reduce",
-            action="store_true",
-            default=False,
-            help="Disable the custom all-reduce kernel and fall back to NCCL.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
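Note: for callers that build ServerArgs programmatically rather than via the CLI, the rename and the relocated flag look like this. A sketch assuming model_path is the only required field; the model path itself is just an example taken from elsewhere in this commit:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="lmms-lab/llava-next-72b",     # example path; any local or HF model path works
    triton_attention_reduce_in_fp32=True,     # formerly attention_reduce_in_fp32
    disable_custom_all_reduce=True,           # same switch as --disable-custom-all-reduce
)
print(args.triton_attention_reduce_in_fp32, args.disable_custom_all_reduce)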
python/sglang/srt/utils.py

@@ -347,7 +347,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)


 def assert_pkg_version(pkg: str, min_version: str, message: str):

@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
python/sglang/test/runners.py

@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from sglang.srt.server import Runtime
-from sglang.srt.utils import is_generation_model

 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt

@@ -63,8 +62,8 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        torch_dtype,
+        is_generation_model,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()

@@ -90,11 +89,8 @@ class HFRunner:
             trust_remote_code=True,
         )

-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
         if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,

@@ -176,16 +172,12 @@ class SRTRunner:
     def __init__(
         self,
         model_path,
+        torch_dtype,
+        is_generation_model,
         tp_size=1,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
         port=5157,
     ):
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
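Note: torch_dtype and is_generation_model are now plain required parameters, and the automatic is_generation_model(model_path) fallback is gone, so test call sites must pass both explicitly. A hypothetical usage sketch; the model name is illustrative, and constructing these runners launches the actual HF model / SRT runtime:

import torch
from sglang.test.runners import HFRunner, SRTRunner

MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"  # illustrative model path

hf_runner = HFRunner(MODEL, torch_dtype=torch.float16, is_generation_model=True)
srt_runner = SRTRunner(MODEL, torch_dtype=torch.float16, is_generation_model=True, tp_size=1)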
scripts/convert_yi_vl.py → scripts/deprecated/convert_yi_vl.py

File moved.