sglang commit bdc1acf6 (unverified)
Authored Jan 07, 2025 by Lianmin Zheng; committed by GitHub on Jan 07, 2025

Misc fix for min_p_sampling, --cuda-graph-bs (#2761)

Parent: 6d08ce2a

Showing 17 changed files with 135 additions and 63 deletions (+135 / -63):

.github/workflows/pr-test.yml                                    +3   -1
benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py    +1   -0
python/pyproject.toml                                            +9   -3
python/sglang/bench_serving.py                                   +4   -1
python/sglang/srt/layers/logits_processor.py                     +5   -0
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py       +16  -5
python/sglang/srt/layers/quantization/__init__.py                +1   -2
python/sglang/srt/managers/data_parallel_controller.py           +2   -0
python/sglang/srt/managers/scheduler.py                          +3   -2
python/sglang/srt/metrics/collector.py                           +22  -30
python/sglang/srt/model_executor/cuda_graph_runner.py            +8   -1
python/sglang/srt/model_executor/model_runner.py                 +8   -1
python/sglang/srt/sampling/sampling_batch_info.py                +1   -0
python/sglang/srt/server.py                                      +4   -6
python/sglang/srt/server_args.py                                 +7   -0
python/sglang/srt/utils.py                                       +6   -5
python/sglang/test/test_utils.py                                 +35  -6

.github/workflows/pr-test.yml
@@ -66,12 +66,14 @@ jobs:
       - name: Run test
         timeout-minutes: 25
         run: |
-          cd test/srt
           RANGE=${{ matrix.range }}
           range_begin=${RANGE%-*}
           range_end=${RANGE#*-}
+          cd test/srt
           python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}

   unit-test-backend-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
...

benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py
@@ -228,6 +228,7 @@ class BenchmarkWorker:
                 hidden_size,
                 topk,
                 dtype_str,
+                False,
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
...

python/pyproject.toml
@@ -16,14 +16,20 @@ classifiers = [
 dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]

 [project.optional-dependencies]
 runtime_common = ["aiohttp", "decord", "fastapi",
     "hf_transfer", "huggingface_hub", "interegular", "modelscope",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post11"]
+    "xgrammar>=0.1.6"
+]
+srt = [
+    "sglang[runtime_common]",
+    "cuda-python",
+    "sgl-kernel>=0.0.2.post11",
+    "torch",
+    "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "flashinfer==0.1.6",
+]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
...

python/sglang/bench_serving.py
@@ -563,7 +563,7 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not os.path.isfile(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
...
@@ -1064,8 +1064,11 @@ async def benchmark(
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "mean_ttft_ms": metrics.mean_ttft_ms,
             "median_ttft_ms": metrics.median_ttft_ms,
+            "mean_itl_ms": metrics.mean_itl_ms,
             "median_itl_ms": metrics.median_itl_ms,
+            "input_throughput": metrics.input_throughput,
             "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
...

python/sglang/srt/layers/logits_processor.py
@@ -117,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None

     def forward(
         self,
...

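Note on the guard added above: a negative final_logit_softcapping value is normalized to None, i.e. "softcapping disabled". For reference, a minimal sketch of the usual softcapping transform this setting controls (an illustrative helper, not the module's exact code path):

from typing import Optional

import torch


def apply_final_logit_softcapping(
    logits: torch.Tensor, cap: Optional[float]
) -> torch.Tensor:
    # A cap of None (or, per the guard above, a negative sentinel that gets
    # normalized to None) means: leave the logits untouched.
    if cap is None:
        return logits
    # Standard softcapping: squash logits smoothly into (-cap, cap).
    return cap * torch.tanh(logits / cap)
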
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -1011,6 +1011,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
             )
         else:
-            torch.sum(
-                intermediate_cache3.view(*intermediate_cache3.shape),
-                dim=1,
+            if topk_ids.shape[1] == 1:
+                out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                    intermediate_cache3[:, 0]
+                )
+            elif topk_ids.shape[1] == 2:
+                torch.add(
+                    intermediate_cache3[:, 0],
+                    intermediate_cache3[:, 1],
+                    out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                ).squeeze(dim=1)
+            elif topk_ids.shape[1] > 2:
+                torch.sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    dim=1,
...

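The new branch specializes the final top-k combine step: for top-1 the sum over the selected-expert dimension is just a copy, for top-2 it is a single add, and only top-k > 2 falls back to the generic torch.sum reduction. A standalone sketch (shapes and names are illustrative, not the kernel's actual buffers) showing the three paths agree:

import torch

num_tokens, hidden = 8, 16
for topk in (1, 2, 4):
    # Per-token, per-selected-expert outputs, analogous to intermediate_cache3 above.
    cache3 = torch.randn(num_tokens, topk, hidden)
    out = torch.empty(num_tokens, hidden)

    if topk == 1:
        out.copy_(cache3[:, 0])                          # top-1: just a copy
    elif topk == 2:
        torch.add(cache3[:, 0], cache3[:, 1], out=out)   # top-2: one add
    else:
        torch.sum(cache3, dim=1, out=out)                # general reduction

    assert torch.allclose(out, cache3.sum(dim=1))
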
python/sglang/srt/layers/quantization/__init__.py
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Callable, Dict, Optional, Type
+from typing import Dict, Type

-import torch

 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
...

python/sglang/srt/managers/data_parallel_controller.py
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto

 import psutil
+import setproctitle
 import zmq

 from sglang.srt.managers.io_struct import (
...
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
...

python/sglang/srt/managers/scheduler.py
@@ -1516,8 +1516,9 @@ class Scheduler:
         return success, message

     def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
...

python/sglang/srt/metrics/collector.py
@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
...
@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
-                0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                0.1,
+                0.25,
+                0.5,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
...

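The replacement buckets trade the old sub-second resolution for coverage of long-running requests (up to 160 s). For context, this is roughly how such a prometheus_client histogram is declared and fed; the metric and label names below are illustrative, not necessarily the collector's exact ones:

from prometheus_client import Histogram

e2e_request_latency = Histogram(
    name="sglang:e2e_request_latency_seconds",   # illustrative metric name
    documentation="Histogram of End-to-end request latency in seconds",
    labelnames=["model_name"],
    buckets=[0.1, 0.25, 0.5, 1, 2, 5, 10, 20, 40, 60, 80, 120, 160],
)

# Record one finished request that took 3.2 seconds.
e2e_request_latency.labels(model_name="my-model").observe(3.2)
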
python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -124,6 +124,13 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
...
@@ -340,8 +347,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
...

python/sglang/srt/model_executor/model_runner.py
@@ -89,6 +89,7 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
...
@@ -117,15 +118,21 @@ class ModelRunner:
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )

             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1

-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
...

python/sglang/srt/sampling/sampling_batch_info.py
@@ -232,6 +232,7 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
...

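This one-liner is the min_p_sampling part of the fix: when two running batches are merged, the merged batch must still require min-p sampling if either side did; previously the flag from `other` was dropped. For reference, an illustrative pure-PyTorch min-p filter (not the fused kernel sglang actually dispatches to):

import torch

def min_p_filter(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # probs: [batch, vocab] softmax probabilities; min_p: [batch] per-request thresholds.
    # Keep tokens whose probability is at least min_p * max_prob, then renormalize.
    max_prob = probs.max(dim=-1, keepdim=True).values
    keep = probs >= min_p.unsqueeze(-1) * max_prob
    filtered = torch.where(keep, probs, torch.zeros_like(probs))
    return filtered / filtered.sum(dim=-1, keepdim=True)

probs = torch.softmax(torch.randn(2, 10), dim=-1)
min_p = torch.tensor([0.1, 0.0])  # 0.0 effectively disables filtering for the second request
print(min_p_filter(probs, min_p))
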
python/sglang/srt/server.py
@@ -127,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""

-    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
     if tokenizer_manager.is_generation:
         gri = GenerateReqInput(
-            input_ids=[0], sampling_params=sampling_params
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
         )
     else:
         gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params=sampling_params
+            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
         )

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
...

python/sglang/srt/server_args.py
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
...
@@ -803,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
...

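Because the new argument uses nargs="+", it accepts a space-separated list of batch sizes, e.g. --cuda-graph-bs 1 2 4 8 16. A minimal argparse sketch of how that value parses (standalone, not sglang's full parser):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cuda-graph-bs",
    type=int,
    nargs="+",
    help="Set the list of batch sizes for cuda graph.",
)

args = parser.parse_args(["--cuda-graph-bs", "1", "2", "4", "8", "16"])
print(args.cuda_graph_bs)  # [1, 2, 4, 8, 16]
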
python/sglang/srt/utils.py
@@ -709,13 +709,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
...
@@ -724,19 +725,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)

-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()
         if size == 0:
             return []

         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
+        dist.broadcast(tensor_data, src=src, group=dist_group)

         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
...

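broadcast_pyobj now takes the broadcast root as a parameter instead of hard-coding src=0. The underlying pattern is a two-step broadcast: first the payload length, then the pickled bytes. A self-contained sketch of that pattern, assuming torch.distributed is already initialized with a CPU-capable backend such as gloo (the helper name is made up for illustration):

import pickle
import torch
import torch.distributed as dist

def broadcast_pickled(obj, rank: int, src: int = 0, group=None):
    # CPU tensors are used throughout, so a gloo group is assumed.
    if rank == src:
        payload = pickle.dumps(obj)
        size = torch.tensor([len(payload)], dtype=torch.long)
        data = torch.frombuffer(bytearray(payload), dtype=torch.uint8)
        dist.broadcast(size, src=src, group=group)  # 1) announce the length
        dist.broadcast(data, src=src, group=group)  # 2) send the bytes
        return obj
    else:
        size = torch.tensor([0], dtype=torch.long)
        dist.broadcast(size, src=src, group=group)
        data = torch.empty(size.item(), dtype=torch.uint8)
        dist.broadcast(data, src=src, group=group)
        return pickle.loads(bytes(data.numpy()))
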
python/sglang/test/test_utils.py
@@ -532,6 +532,8 @@ def run_bench_serving(
     request_rate,
     other_server_args,
     dataset_name="random",
+    dataset_path="",
+    tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
...
@@ -553,9 +555,9 @@ def run_bench_serving(
         host=None,
         port=None,
         dataset_name=dataset_name,
-        dataset_path="",
+        dataset_path=dataset_path,
         model=None,
-        tokenizer=None,
+        tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
         random_input_len=random_input_len,
...
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
 STDOUT_FILENAME = "stdout.txt"


-def read_output(output_lines):
+def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
     """Print the output in real time with another thread."""
-    while not os.path.exists(STDERR_FILENAME):
+    while not os.path.exists(filename):
         time.sleep(1)

     pt = 0
     while pt >= 0:
-        if pt > 0 and not os.path.exists(STDERR_FILENAME):
+        if pt > 0 and not os.path.exists(filename):
             break
-        lines = open(STDERR_FILENAME).readlines()
+        lines = open(filename).readlines()
         for line in lines[pt:]:
             print(line, end="", flush=True)
             output_lines.append(line)
...
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
     assert has_abort


+def run_command_and_capture_output(command, env: Optional[dict] = None):
+    stdout = open(STDOUT_FILENAME, "w")
+    stderr = open(STDERR_FILENAME, "w")
+    process = subprocess.Popen(
+        command, stdout=stdout, stderr=stderr, env=env, text=True
+    )
+
+    # Launch a thread to stream the output
+    output_lines = []
+    t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
+    t.start()
+
+    # Join the process
+    process.wait()
+    stdout.close()
+    stderr.close()
+    if os.path.exists(STDOUT_FILENAME):
+        os.remove(STDOUT_FILENAME)
+    if os.path.exists(STDERR_FILENAME):
+        os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
+    t.join()
+
+    return output_lines
+
+
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
...

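The new run_command_and_capture_output helper redirects a subprocess's stdout/stderr to files, streams stdout back through read_output on a side thread, and returns the captured lines. A hypothetical usage sketch (the command here is only an illustration):

from sglang.test.test_utils import run_command_and_capture_output

# Run a short command and collect what it printed to stdout.
output_lines = run_command_and_capture_output(
    ["python3", "-c", "import time; print('ready', flush=True); time.sleep(2)"]
)
print("captured:", output_lines)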