Unverified commit bdc1acf6, authored by Lianmin Zheng, committed by GitHub

Misc fix for min_p_sampling, --cuda-graph-bs (#2761)

parent 6d08ce2a
@@ -66,12 +66,14 @@ jobs:
       - name: Run test
         timeout-minutes: 25
         run: |
-          cd test/srt
           RANGE=${{ matrix.range }}
           range_begin=${RANGE%-*}
           range_end=${RANGE#*-}
+          cd test/srt
           python3 run_suite.py --suite per-commit --range-begin ${range_begin} --range-end ${range_end}

   unit-test-backend-2-gpu:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
     runs-on: 2-gpu-runner
...
@@ -228,6 +228,7 @@ class BenchmarkWorker:
                 hidden_size,
                 topk,
                 dtype_str,
+                False,
             )
         else:
             config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
...
@@ -16,14 +16,20 @@ classifiers = [
 dependencies = ["requests", "tqdm", "numpy", "IPython", "setproctitle"]

 [project.optional-dependencies]
-runtime_common = ["aiohttp", "decord", "fastapi",
+runtime_common = [
+    "aiohttp", "decord", "fastapi",
     "hf_transfer", "huggingface_hub", "interegular", "modelscope",
     "orjson", "outlines>=0.0.44,<0.1.0",
     "packaging", "pillow", "prometheus-client>=0.20.0",
     "psutil", "pydantic", "python-multipart",
     "pyzmq>=25.1.2", "torchao>=0.7.0", "uvicorn", "uvloop",
-    "xgrammar>=0.1.6"]
-srt = ["sglang[runtime_common]", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1", "cuda-python", "flashinfer==0.1.6", "sgl-kernel>=0.0.2.post11"]
+    "xgrammar>=0.1.6"
+]
+srt = [
+    "sglang[runtime_common]", "cuda-python",
+    "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
+    "flashinfer==0.1.6"
+]

 # HIP (Heterogeneous-computing Interface for Portability) for AMD
 # => base docker rocm/vllm-dev:20241022, not from public vllm whl
...
@@ -563,7 +563,7 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not os.path.isfile(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -1064,8 +1064,11 @@ async def benchmark(
             "total_output_tokens_retokenized": metrics.total_output_retokenized,
             "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
             "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
+            "mean_ttft_ms": metrics.mean_ttft_ms,
             "median_ttft_ms": metrics.median_ttft_ms,
+            "mean_itl_ms": metrics.mean_itl_ms,
             "median_itl_ms": metrics.median_itl_ms,
+            "input_throughput": metrics.input_throughput,
             "output_throughput": metrics.output_throughput,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
...
@@ -117,6 +117,11 @@ class LogitsProcessor(nn.Module):
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
+        if (
+            self.final_logit_softcapping is not None
+            and self.final_logit_softcapping < 0
+        ):
+            self.final_logit_softcapping = None

     def forward(
         self,
...
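The hunk above appears to treat a negative final_logit_softcapping value in the Hugging Face config as "softcapping disabled" by resetting it to None before the forward pass reads it. For reference, a minimal sketch of how tanh-based final-logit softcapping is commonly applied; the standalone helper name is an assumption, not the repository's exact code:

import torch
from typing import Optional

def apply_final_logit_softcapping(logits: torch.Tensor, cap: Optional[float]) -> torch.Tensor:
    # A missing or negative cap means "no softcapping", mirroring the None fallback above.
    if cap is None or cap < 0:
        return logits
    # Standard tanh softcapping: smoothly bounds logits to the open interval (-cap, cap).
    return torch.tanh(logits / cap) * cap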
@@ -1011,6 +1011,17 @@ def fused_experts_impl(
                     out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 )
             else:
+                if topk_ids.shape[1] == 1:
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
+                        intermediate_cache3[:, 0]
+                    )
+                elif topk_ids.shape[1] == 2:
+                    torch.add(
+                        intermediate_cache3[:, 0],
+                        intermediate_cache3[:, 1],
+                        out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                    ).squeeze(dim=1)
+                elif topk_ids.shape[1] > 2:
                     torch.sum(
                         intermediate_cache3.view(*intermediate_cache3.shape),
                         dim=1,
...
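The fused-MoE hunk above appears to special-case the final combine over the top-k expert outputs: a plain copy when topk is 1, a single add when topk is 2, and a generic reduction otherwise, so the common small-topk cases skip an unnecessary sum kernel. A simplified sketch of that combine step, with hypothetical names and without the chunking logic of the real function:

import torch

def combine_topk_outputs(per_expert_out: torch.Tensor, out: torch.Tensor) -> None:
    # per_expert_out: [num_tokens, topk, hidden_size], already weighted per expert.
    # out: [num_tokens, hidden_size], buffer receiving the combined result.
    topk = per_expert_out.shape[1]
    if topk == 1:
        out.copy_(per_expert_out[:, 0])  # nothing to reduce
    elif topk == 2:
        torch.add(per_expert_out[:, 0], per_expert_out[:, 1], out=out)  # cheaper than a generic sum
    else:
        torch.sum(per_expert_out, dim=1, out=out)  # general case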
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
-from typing import Callable, Dict, Optional, Type
+from typing import Dict, Type

-import torch

 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
...
@@ -20,6 +20,7 @@ import threading
 from enum import Enum, auto

 import psutil
+import setproctitle
 import zmq

 from sglang.srt.managers.io_struct import (
@@ -230,6 +231,7 @@ def run_data_parallel_controller_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    setproctitle.setproctitle("sglang::data_parallel_controller")
     configure_logger(server_args)
     parent_process = psutil.Process().parent()
...
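The data-parallel-controller hunk simply names the subprocess so it is easy to identify in process listings. A tiny sketch of the same pattern:

import setproctitle

# Once the title is set, the controller shows up under a recognizable name, e.g.
#   ps aux | grep "sglang::"
setproctitle.setproctitle("sglang::data_parallel_controller")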
@@ -1516,8 +1516,9 @@ class Scheduler:
         return success, message

     def update_weights_from_distributed(
-        self, recv_req: UpdateWeightsFromDistributedReqInput
-    ):
+        self,
+        recv_req: UpdateWeightsFromDistributedReqInput,
+    ) -> Tuple[bool, str]:
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
...
@@ -114,26 +114,20 @@ class TokenizerMetricsCollector:
             documentation="Histogram of time to first token in seconds.",
             labelnames=labels.keys(),
             buckets=[
-                0.001,
-                0.005,
-                0.01,
-                0.02,
-                0.04,
-                0.06,
-                0.08,
                 0.1,
                 0.25,
                 0.5,
                 0.75,
-                1.0,
-                2.5,
-                5.0,
-                7.5,
-                10.0,
-                15.0,
-                20.0,
-                25.0,
-                30.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
@@ -168,21 +162,19 @@ class TokenizerMetricsCollector:
             documentation="Histogram of End-to-end request latency in seconds",
             labelnames=labels.keys(),
             buckets=[
-                0.3,
+                0.1,
+                0.25,
                 0.5,
-                0.8,
-                1.0,
-                1.5,
-                2.0,
-                2.5,
-                5.0,
-                10.0,
-                15.0,
-                20.0,
-                30.0,
-                40.0,
-                50.0,
-                60.0,
+                1,
+                2,
+                5,
+                10,
+                20,
+                40,
+                60,
+                80,
+                120,
+                160,
             ],
         )
...
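The metrics hunk drops the sub-100 ms buckets and extends both the time-to-first-token and end-to-end latency histograms up to 160 seconds. A self-contained sketch of declaring such a histogram with prometheus_client; the metric name and label below are hypothetical, not the collector's actual identifiers:

from prometheus_client import Histogram

ttft_seconds = Histogram(
    "example_time_to_first_token_seconds",  # hypothetical metric name
    "Histogram of time to first token in seconds.",
    labelnames=["model_name"],
    buckets=[0.1, 0.25, 0.5, 0.75, 1, 2, 5, 10, 20, 40, 60, 80, 120, 160],
)

# Record one observation for a labeled series.
ttft_seconds.labels(model_name="demo-model").observe(0.42)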
@@ -124,6 +124,13 @@ class CudaGraphRunner:
         self.tp_size = self.model_runner.tp_size

         # Batch sizes to capture
+        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
+        if self.capture_bs is None:
+            if model_runner.server_args.disable_cuda_graph_padding:
+                self.capture_bs = list(range(1, 33)) + [64, 128]
+            else:
+                self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
         if model_runner.server_args.disable_cuda_graph_padding:
             self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
@@ -340,8 +347,8 @@ class CudaGraphRunner:
             top_logprobs_nums=[0] * bs,
             positions=positions,
             global_num_tokens=global_num_tokens,
-            mrope_positions=mrope_positions,
             gathered_buffer=gathered_buffer,
+            mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=(
...
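The CudaGraphRunner hunk appears to let an explicit --cuda-graph-bs list override the built-in capture schedules, falling back to the previous defaults when the flag is unset. A standalone sketch of that selection logic; the helper name is an assumption:

from typing import List, Optional

def resolve_capture_batch_sizes(
    cuda_graph_bs: Optional[List[int]], disable_cuda_graph_padding: bool
) -> List[int]:
    # An explicit list from --cuda-graph-bs wins over the defaults.
    if cuda_graph_bs is not None:
        return sorted(set(cuda_graph_bs))
    if disable_cuda_graph_padding:
        return list(range(1, 33)) + [64, 128]  # dense small sizes when padding is disabled
    return [1, 2, 4] + [i * 8 for i in range(1, 21)]  # 1, 2, 4, then multiples of 8 up to 160

print(resolve_capture_batch_sizes(None, False))        # default schedule
print(resolve_capture_batch_sizes([1, 8, 32], False))  # user-provided sizes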
@@ -89,6 +89,7 @@ class ModelRunner:
         self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.should_log = tp_rank == 0
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -117,15 +118,21 @@ class ModelRunner:
         if self.is_multimodal:
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )

             if self.model_config.hf_config.architectures == [
                 "MllamaForConditionalGeneration"
             ]:
                 logger.info("Automatically turn off --chunked-prefill-size for mllama.")
                 server_args.chunked_prefill_size = -1

-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
                 logger.info(
                     "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
                 )
...
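The ModelRunner hunk adds a should_log flag (true only when tp_rank is 0) alongside the new log line about lowering --mem-fraction-static for multimodal models. A minimal sketch of the rank-gated logging pattern this flag enables; the helper below is hypothetical:

import logging

logger = logging.getLogger("sglang.sketch")

def log_on_rank0(should_log: bool, msg: str) -> None:
    # Only one rank prints, so tensor-parallel runs don't repeat every message N times.
    if should_log:
        logger.info(msg)

log_on_rank0(True, "Automatically reduce --mem-fraction-static because this is a multimodal model.")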
@@ -232,6 +232,7 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

     def apply_logits_bias(self, logits: torch.Tensor):
         # Apply logit_bias
...
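The SamplingBatchInfo hunk is the min_p fix from the commit title: when two running batches are merged, the merged batch must keep min-p sampling enabled if either side needed it, hence the OR on need_min_p_sampling. For context, a minimal sketch of min-p filtering in its standard formulation (not the repository's sampling kernel):

import torch

def min_p_filter(probs: torch.Tensor, min_p: torch.Tensor) -> torch.Tensor:
    # Keep tokens whose probability is at least min_p * max_prob, then renormalize.
    # probs: [batch, vocab]; min_p: [batch], where 0.0 disables the filter for that row.
    max_probs = probs.max(dim=-1, keepdim=True).values
    mask = probs < min_p.unsqueeze(-1) * max_probs
    filtered = probs.masked_fill(mask, 0.0)
    return filtered / filtered.sum(dim=-1, keepdim=True)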
@@ -127,14 +127,12 @@ async def health() -> Response:
 async def health_generate(request: Request) -> Response:
     """Check the health of the inference server by generating one token."""

+    sampling_params = {"max_new_tokens": 1, "temperature": 0.7}
     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
     else:
-        gri = EmbeddingReqInput(
-            input_ids=[0], sampling_params={"max_new_tokens": 1, "temperature": 0.7}
-        )
+        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
...
@@ -148,6 +148,7 @@ class ServerArgs:
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
+    cuda_graph_bs: Optional[List[int]] = None
     torchao_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
@@ -803,6 +804,12 @@ class ServerArgs:
             default=ServerArgs.cuda_graph_max_bs,
             help="Set the maximum batch size for cuda graph.",
         )
+        parser.add_argument(
+            "--cuda-graph-bs",
+            type=int,
+            nargs="+",
+            help="Set the list of batch sizes for cuda graph.",
+        )
         parser.add_argument(
             "--torchao-config",
             type=str,
...
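The ServerArgs hunk wires the new --cuda-graph-bs flag through argparse with nargs="+", so a space-separated list on the command line becomes a List[int]. A small standalone example of how such a flag parses:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cuda-graph-max-bs", type=int, default=None)
parser.add_argument(
    "--cuda-graph-bs",
    type=int,
    nargs="+",
    default=None,
    help="Set the list of batch sizes for cuda graph.",
)

# Equivalent to launching with: --cuda-graph-bs 1 2 4 8 16
args = parser.parse_args(["--cuda-graph-bs", "1", "2", "4", "8", "16"])
print(args.cuda_graph_bs)  # [1, 2, 4, 8, 16]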
@@ -709,13 +709,14 @@ def broadcast_pyobj(
     data: List[Any],
     rank: int,
     dist_group: Optional[torch.distributed.ProcessGroup] = None,
+    src: int = 0,
 ):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""

     if rank == 0:
         if len(data) == 0:
             tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
         else:
             serialized_data = pickle.dumps(data)
             size = len(serialized_data)
@@ -724,19 +725,19 @@ def broadcast_pyobj(
             )
             tensor_size = torch.tensor([size], dtype=torch.long)

-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
+            dist.broadcast(tensor_size, src=src, group=dist_group)
+            dist.broadcast(tensor_data, src=src, group=dist_group)
         return data
     else:
         tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
+        dist.broadcast(tensor_size, src=src, group=dist_group)
         size = tensor_size.item()

         if size == 0:
             return []

         tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
+        dist.broadcast(tensor_data, src=src, group=dist_group)

         serialized_data = bytes(tensor_data.cpu().numpy())
         data = pickle.loads(serialized_data)
...
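The broadcast_pyobj hunk generalizes the hard-coded src=0 to a src parameter. The underlying pattern is a two-step broadcast: first the payload size, then the pickled bytes. A condensed sketch of that pattern; it assumes an already-initialized torch.distributed process group, and the function name is hypothetical:

import pickle

import torch
import torch.distributed as dist

def broadcast_pyobj_sketch(obj, rank: int, src: int = 0, group=None):
    if rank == src:
        payload = torch.frombuffer(bytearray(pickle.dumps(obj)), dtype=torch.uint8)
        size = torch.tensor([payload.numel()], dtype=torch.long)
        dist.broadcast(size, src=src, group=group)     # 1) announce the length
        dist.broadcast(payload, src=src, group=group)  # 2) send the bytes
        return obj
    size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=src, group=group)
    payload = torch.empty(int(size.item()), dtype=torch.uint8)
    dist.broadcast(payload, src=src, group=group)
    return pickle.loads(payload.numpy().tobytes())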
@@ -532,6 +532,8 @@ def run_bench_serving(
     request_rate,
     other_server_args,
     dataset_name="random",
+    dataset_path="",
+    tokenizer=None,
     random_input_len=4096,
     random_output_len=2048,
     disable_stream=False,
@@ -553,9 +555,9 @@ def run_bench_serving(
         host=None,
         port=None,
         dataset_name=dataset_name,
-        dataset_path="",
+        dataset_path=dataset_path,
         model=None,
-        tokenizer=None,
+        tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
         random_input_len=random_input_len,
@@ -657,16 +659,16 @@ STDERR_FILENAME = "stderr.txt"
 STDOUT_FILENAME = "stdout.txt"


-def read_output(output_lines):
+def read_output(output_lines: List[str], filename: str = STDERR_FILENAME):
     """Print the output in real time with another thread."""
-    while not os.path.exists(STDERR_FILENAME):
+    while not os.path.exists(filename):
         time.sleep(1)

     pt = 0
     while pt >= 0:
-        if pt > 0 and not os.path.exists(STDERR_FILENAME):
+        if pt > 0 and not os.path.exists(filename):
             break
-        lines = open(STDERR_FILENAME).readlines()
+        lines = open(filename).readlines()
         for line in lines[pt:]:
             print(line, end="", flush=True)
             output_lines.append(line)
@@ -747,6 +749,33 @@ def run_and_check_memory_leak(
         assert has_abort


+def run_command_and_capture_output(command, env: Optional[dict] = None):
+    stdout = open(STDOUT_FILENAME, "w")
+    stderr = open(STDERR_FILENAME, "w")
+    process = subprocess.Popen(
+        command, stdout=stdout, stderr=stderr, env=env, text=True
+    )
+
+    # Launch a thread to stream the output
+    output_lines = []
+    t = threading.Thread(target=read_output, args=(output_lines, STDOUT_FILENAME))
+    t.start()
+
+    # Join the process
+    process.wait()
+    stdout.close()
+    stderr.close()
+    if os.path.exists(STDOUT_FILENAME):
+        os.remove(STDOUT_FILENAME)
+    if os.path.exists(STDERR_FILENAME):
+        os.remove(STDERR_FILENAME)
+    kill_process_tree(process.pid)
+
+    t.join()
+
+    return output_lines
+
+
 def run_mmlu_test(
     disable_radix_cache=False,
     enable_mixed_chunk=False,
...
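The test-utility hunk adds run_command_and_capture_output, which runs a command, streams its stdout from a temporary file on a background thread, and returns the captured lines. A hypothetical usage, assuming the helper is importable from sglang.test.test_utils:

from sglang.test.test_utils import run_command_and_capture_output  # assumed import path

output_lines = run_command_and_capture_output(
    ["python3", "-c", "print('hello from a subprocess')"]
)
print("captured:", "".join(output_lines).strip())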