Unverified Commit f6af3a65 authored by Lianmin Zheng, committed by GitHub

Cleanup readme, llava examples, usage examples and nccl init (#1194)

parent c9064e6f
@@ -4,7 +4,7 @@ Usage:
 # Installing latest sglang.

 # Endpoint Service CLI:
-# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
+python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8

 python3 http_qwen_llava_test.py
@@ -16,7 +16,6 @@ import argparse
 import asyncio
 import copy
 import json
-import time

 import aiohttp
 import requests
...
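For reference, http_qwen_llava_test.py exercises the launched server over plain HTTP. A minimal sketch of such a request is below; the /generate route and the text/image_data/sampling_params fields are assumptions about the SGLang HTTP API at this version, and the exact prompt format depends on the chat template, so treat this as illustrative rather than a copy of the test script.

import requests

# Assumed request shape for the SGLang server started on port 30000 above.
resp = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Describe this image in one sentence.",
        # Assumed field: a URL or local path to the image.
        "image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
        "sampling_params": {"max_new_tokens": 64, "temperature": 0},
    },
    timeout=60,
)
print(resp.json())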
"""
Usage: python3 srt_example_llava.py
"""
from PIL import ImageFile
import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
@sgl.function
def image_qa(s, image, question):
s += sgl.user(sgl.image(image) + question)
s += sgl.assistant(sgl.gen("answer"))
def single():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"], "\n")
def stream():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(
image=pil_image,
question="Please generate short caption for this image.",
max_new_tokens=512,
temperature=0,
stream=True,
)
for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()
def batch():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
states = image_qa.run_batch(
[
{"image": pil_image, "question": "What is this?"},
{"image": pil_image, "question": "What is this?"},
],
max_new_tokens=512,
)
for s in states:
print(s["answer"], "\n")
if __name__ == "__main__":
import multiprocessing as mp
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path="lmms-lab/llama3-llava-next-8b",
tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
)
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
# runtime = sgl.Runtime(
# model_path="lmms-lab/llava-next-72b",
# tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
# )
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
# Run a single request
print("\n========== single ==========\n")
single()
# Stream output
print("\n========== stream ==========\n")
stream()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
runtime.shutdown()
@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

-    model_config = ModelConfig(path=server_args.model_path)
+    model_config = ModelConfig(
+        server_args.model_path,
+        server_args.trust_remote_code,
+        context_length=server_args.context_length,
+    )
     model_runner = ModelRunner(
         model_config=model_config,
         mem_fraction_static=server_args.mem_fraction_static,
...
-from dataclasses import dataclass, field
+from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Tuple


 class ChatTemplateStyle(Enum):
...
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, model_overide_args, None)
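The script above only fills in model_overide_args and defers everything else to the standard ServerArgs CLI, so it is started like any other SGLang server. A hypothetical invocation (the script file name and the model path are placeholders, not taken from this commit):

python3 launch_server_llava_vid.py --model-path lmms-lab/LLaVA-NeXT-Video-7B --port 30000 --tp-size 1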
@@ -26,7 +26,7 @@ import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict

-if global_server_args_dict.get("attention_reduce_in_fp32", False):
+if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
     REDUCE_TRITON_TYPE = tl.float32
     REDUCE_TORCH_TYPE = torch.float32
 else:
...
@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
         weight_name: str,
         shard_id: int,
         expert_id: int,
-        pre_sharded: bool,
+        use_presharded_weights: bool = False,
     ):
         param_data = param.data

@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
         else:
             tp_rank = get_tensor_model_parallel_rank()
             shard_size = self.intermediate_size_per_partition
-            if pre_sharded:
+            if use_presharded_weights:
                 shard = slice(None)
             else:
                 shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
...
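The rename from pre_sharded to use_presharded_weights (now defaulting to False) makes the flag's intent explicit: when a checkpoint is already sharded per tensor-parallel rank, the loader keeps the tensor whole instead of slicing out this rank's partition. A standalone sketch of that slicing decision, with made-up sizes rather than the FusedMoE code itself:

import torch

def select_shard(weight, tp_rank, shard_size, use_presharded_weights=False):
    # Mirrors the branch above: pre-sharded weights are taken whole,
    # otherwise each rank slices its own partition along dim 0.
    if use_presharded_weights:
        shard = slice(None)
    else:
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    return weight[shard]

full = torch.randn(256, 128)     # hypothetical full intermediate weight (4 ranks x 64 rows)
per_rank = torch.randn(64, 128)  # hypothetical pre-sharded weight for one rank

print(select_shard(full, tp_rank=1, shard_size=64).shape)                                    # torch.Size([64, 128])
print(select_shard(per_rank, tp_rank=1, shard_size=64, use_presharded_weights=True).shape)  # torch.Size([64, 128])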
@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):

         if hasattr(self.config, "final_logit_softcapping"):
             last_logits.div_(self.config.final_logit_softcapping)
-            last_logits = torch.tanh(last_logits)
+            torch.tanh(last_logits, out=last_logits)
             last_logits.mul_(self.config.final_logit_softcapping)

         # Return only last_logits if logprob is not requested
@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):

         if hasattr(self.config, "final_logit_softcapping"):
             all_logits.div_(self.config.final_logit_softcapping)
-            all_logits = torch.tanh(all_logits)
+            torch.tanh(all_logits, out=all_logits)
             all_logits.mul_(self.config.final_logit_softcapping)

         all_logprobs = all_logits
...
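Both call sites implement final-logit softcapping, i.e. logits are mapped to cap * tanh(logits / cap). Writing the tanh with out= keeps the whole sequence in place, so no intermediate tensor is allocated and the following mul_ sees the capped values. A small self-contained sketch of the same pattern:

import torch

def softcap_(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # In-place softcapping: values are squashed smoothly into (-cap, cap).
    logits.div_(cap)
    torch.tanh(logits, out=logits)
    logits.mul_(cap)
    return logits

x = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(softcap_(x, cap=30.0))  # large magnitudes saturate near +/-30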
@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
-    "attention_reduce_in_fp32": False,
+    "triton_attention_reduce_in_fp32": False,
     "enable_mla": False,
 }
...
@@ -606,6 +606,9 @@ class TokenizerManager:
         return background_tasks

     def create_handle_loop(self):
+        if not self.to_create_loop:
+            return
+
         self.to_create_loop = False
         loop = asyncio.get_event_loop()
         loop.create_task(self.handle_loop())
...
@@ -20,7 +20,6 @@ import importlib
 import importlib.resources
 import logging
 import pkgutil
-import warnings
 from functools import lru_cache
 from typing import Optional, Type
@@ -91,23 +90,35 @@ class ModelRunner:
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
-                "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
             }
         )

+        min_per_gpu_memory = self.init_torch_distributed()
+        self.load_model()
+        self.init_memory_pool(
+            min_per_gpu_memory,
+            server_args.max_num_reqs,
+            server_args.max_total_tokens,
+        )
+        self.init_cublas()
+        self.init_flashinfer()
+        self.init_cuda_graphs()
+
+    def init_torch_distributed(self):
         # Init torch distributed
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
-        if not server_args.enable_p2p_check:
+        if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
-        if server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        if self.server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
-        set_custom_all_reduce(not server_args.disable_custom_all_reduce)
+        set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
@@ -116,32 +127,28 @@ class ModelRunner:
             distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-        total_gpu_memory = get_available_gpu_memory(
+        min_per_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()

+        # Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
+        # so we disable padding in cuda graph.
+        if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
+            self.server_args.disable_cuda_graph_padding = True
+            logger.info(
+                "Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
+            )
+
+        # Check memory for tensor parallelism
         if self.tp_size > 1:
-            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
-            if total_local_gpu_memory < total_gpu_memory * 0.9:
+            local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

-        # Load the model and create memory pool
-        self.load_model()
-        self.init_memory_pool(
-            total_gpu_memory,
-            server_args.max_num_reqs,
-            server_args.max_total_tokens,
-        )
-        self.init_cublas()
-        self.init_flashinfer()
-
-        if self.is_generation:
-            # FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
-            # Capture cuda graphs
-            self.init_cuda_graphs()
+        return min_per_gpu_memory

     def load_model(self):
         logger.info(
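The refactor splits distributed setup out of __init__: init_torch_distributed() returns the minimum available GPU memory across ranks, and init_memory_pool() sizes the KV cache from that minimum. The 0.9 check flags any rank whose local free memory sits far above the global minimum, which usually means another process is occupying some of the GPUs. A toy restatement of just that check, with hypothetical numbers and outside the ModelRunner class:

def check_memory_balance(min_per_gpu_memory: float, local_gpu_memory: float) -> None:
    # Same condition as above (memory in GB): the global minimum should be
    # within roughly 10% of this rank's local free memory.
    if min_per_gpu_memory < local_gpu_memory * 0.9:
        raise ValueError(
            "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
        )

check_memory_balance(min_per_gpu_memory=70.0, local_gpu_memory=72.0)    # passes
# check_memory_balance(min_per_gpu_memory=40.0, local_gpu_memory=72.0)  # would raise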
@@ -150,7 +157,7 @@ class ModelRunner:
         )
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
-                "Compute capability below sm80 use float16 due to lack of bfloat16 support."
+                "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
@@ -168,8 +175,9 @@ class ModelRunner:
             skip_tokenizer_init=True,
         )

+        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
+        # Drop this after Sept, 2024.
         if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
             self.model_config.hf_config.num_key_value_heads = 8
             self.vllm_model_config.hf_config.num_key_value_heads = 8
             monkey_patch_vllm_qvk_linear_loader()
@@ -191,8 +199,8 @@ class ModelRunner:
             cache_config=None,
         )
         self.sliding_window_size = (
-            self.model.get_window_size()
-            if hasattr(self.model, "get_window_size")
+            self.model.get_attention_sliding_window_size()
+            if hasattr(self.model, "get_attention_sliding_window_size")
             else None
         )
         self.is_generation = is_generation_model(
@@ -206,7 +214,8 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

-    def update_weights(self, model_path, load_format):
+    def update_weights(self, model_path: str, load_format: str):
+        """Update weights in-place."""
         from vllm.model_executor.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
@@ -222,6 +231,7 @@ class ModelRunner:
         target_device = torch.device(self.device_config.device)

         try:
+            # TODO: Use a better method to check this
             vllm_model_config = VllmModelConfig(
                 model=model_path,
                 quantization=self.server_args.quantization,
@@ -291,7 +301,7 @@ class ModelRunner:
         logger.info(f"[gpu={self.gpu_id}] Update weights end.")
         return True, "Succeeded to update model weights"

-    def profile_max_num_token(self, total_gpu_memory):
+    def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
@@ -319,7 +329,10 @@ class ModelRunner:
         return max_num_token

     def init_memory_pool(
-        self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
+        self,
+        total_gpu_memory: int,
+        max_num_reqs: int = None,
+        max_total_tokens: int = None,
     ):
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if max_total_tokens is not None:
@@ -388,6 +401,7 @@ class ModelRunner:
         return c

     def init_flashinfer(self):
+        """Init flashinfer attention kernel wrappers."""
         if self.server_args.disable_flashinfer:
             assert (
                 self.sliding_window_size is None
@@ -448,6 +462,11 @@ class ModelRunner:
         )

     def init_cuda_graphs(self):
+        """Capture cuda graphs."""
+        if not self.is_generation:
+            # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
+            return
+
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

         if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
@@ -457,7 +476,12 @@ class ModelRunner:
         logger.info(
             f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
         )
-        batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
+
+        if self.server_args.disable_cuda_graph_padding:
+            batch_size_list = list(range(1, 32)) + [64, 128]
+        else:
+            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
         self.cuda_graph_runner = CudaGraphRunner(
             self,
             max_batch_size_to_capture=max(batch_size_list),
...
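With cuda-graph padding disabled (now forced under multi-node tensor parallelism), every decode batch size must be captured exactly, so the list is dense at the low end; with padding enabled, a sparser list suffices because smaller batches are padded up to the next captured size, and the upper bound grows from 128 to 160. The two lists evaluate to:

no_padding = list(range(1, 32)) + [64, 128]
with_padding = [1, 2, 4] + [i * 8 for i in range(1, 21)]

print(no_padding)    # [1, 2, 3, ..., 31, 64, 128]
print(with_padding)  # [1, 2, 4, 8, 16, 24, ..., 152, 160]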
@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata

 # Aligned with HF's implementation, using sliding window inclusive with the last token
 # SGLang assumes exclusive
-def get_window_size(config):
+def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else None,
+            sliding_window_size=(
+                get_attention_sliding_window_size(config)
+                if use_sliding_window
+                else None
+            ),
             logit_cap=self.config.attn_logit_softcapping,
         )
@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )

-    def get_window_size(self):
-        return get_window_size(self.config)
+    def get_attention_sliding_window_size(self):
+        return get_attention_sliding_window_size(self.config)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
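The renamed helper keeps the same off-by-one conversion: Hugging Face counts the sliding window inclusive of the current token, while SGLang treats it as exclusive, hence config.sliding_window - 1. A one-line restatement with a hypothetical Gemma-2-style config value:

def get_attention_sliding_window_size(sliding_window: int) -> int:
    # HF window is inclusive of the last token; SGLang assumes exclusive.
    return sliding_window - 1

print(get_attention_sliding_window_size(4096))  # 4095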
@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = Grok1Model(config, quant_config=quant_config)
-        # self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
-        self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
-        self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        self.logits_processor = LogitsProcessor(config)

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        self.use_presharded_weights = True

         warnings.filterwarnings("ignore", category=FutureWarning)

     def forward(
@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
+
+                if self.use_presharded_weights:
+                    extra_kwargs = {
+                        "use_presharded_weights": self.use_presharded_weights
+                    }
+                else:
+                    extra_kwargs = {}
+
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
                     weight_name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    pre_sharded=get_tensor_model_parallel_world_size() > 1,
+                    **extra_kwargs,
                 )
                 break
             else:
...
@@ -81,13 +81,12 @@ class ServerArgs:
     disable_cuda_graph: bool = False
     disable_cuda_graph_padding: bool = False
     disable_disk_cache: bool = False
+    disable_custom_all_reduce: bool = False
     enable_mixed_chunk: bool = False
     enable_torch_compile: bool = False
    enable_p2p_check: bool = False
     enable_mla: bool = False
-    attention_reduce_in_fp32: bool = False
-    efficient_weight_load: bool = False
-    disable_custom_all_reduce: bool = False
+    triton_attention_reduce_in_fp32: bool = False

     # Distributed args
     nccl_init_addr: Optional[str] = None
@@ -404,6 +403,12 @@ class ServerArgs:
             action="store_true",
             help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
         )
+        parser.add_argument(
+            "--disable-custom-all-reduce",
+            action="store_true",
+            default=False,
+            help="Disable the custom all-reduce kernel and fall back to NCCL.",
+        )
         parser.add_argument(
             "--enable-mixed-chunk",
             action="store_true",
@@ -425,7 +430,7 @@ class ServerArgs:
             help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
         )
         parser.add_argument(
-            "--attention-reduce-in-fp32",
+            "--triton-attention-reduce-in-fp32",
             action="store_true",
             help="Cast the intermediate attention results to fp32 to avoid possible crashes related to fp16."
             "This only affects Triton attention kernels.",
@@ -435,12 +440,6 @@ class ServerArgs:
             action="store_true",
             help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
         )
-        parser.add_argument(
-            "--disable-custom-all-reduce",
-            action="store_true",
-            default=False,
-            help="Disable the custom all-reduce kernel and fall back to NCCL.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
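Both options now map directly onto the ServerArgs fields declared above. A hedged sketch of setting them programmatically rather than via the CLI (the model path is a placeholder, and unlisted fields keep their defaults):

from sglang.srt.server import ServerArgs

server_args = ServerArgs(
    model_path="lmms-lab/llama3-llava-next-8b",  # placeholder model
    disable_custom_all_reduce=True,              # fall back to NCCL all-reduce
    triton_attention_reduce_in_fp32=True,        # fp32 reduction in Triton attention kernels
)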
@@ -347,7 +347,7 @@ def suppress_other_loggers():
         logging.WARN
     )
     logging.getLogger("vllm.selector").setLevel(logging.WARN)
-    logging.getLogger("vllm.utils").setLevel(logging.WARN)
+    logging.getLogger("vllm.utils").setLevel(logging.ERROR)


 def assert_pkg_version(pkg: str, min_version: str, message: str):
@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
             quant_method = getattr(module, "quant_method", None)
             if quant_method is not None:
                 quant_method.process_weights_after_loading(module)
-            # FIXME: Remove this after Mixtral is updated
-            # to use quant_method.
-            if hasattr(module, "process_weights_after_loading"):
-                module.process_weights_after_loading()
         # NOTE(woosuk): For accurate performance evaluation, we assign
         # random values to the weights.
...
@@ -24,7 +24,6 @@ import torch.nn.functional as F
 from transformers import AutoModelForCausalLM, AutoTokenizer

 from sglang.srt.server import Runtime
-from sglang.srt.utils import is_generation_model

 DEFAULT_PROMPTS = [
     # the output of gemma-2-2b from SRT is unstable on the commented prompt
@@ -63,8 +62,8 @@ class HFRunner:
     def __init__(
         self,
         model_path,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
+        torch_dtype,
+        is_generation_model,
     ):
         self.in_queue = multiprocessing.Queue()
         self.out_queue = multiprocessing.Queue()
@@ -90,11 +89,8 @@ class HFRunner:
             trust_remote_code=True,
         )

-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
         if self.is_generation_model:
             self.model = AutoModelForCausalLM.from_pretrained(
                 model_path,
@@ -176,16 +172,12 @@ class SRTRunner:
     def __init__(
         self,
         model_path,
+        torch_dtype,
+        is_generation_model,
         tp_size=1,
-        torch_dtype=torch.float16,
-        is_generation_model=None,
         port=5157,
     ):
-        self.is_generation_model = (
-            is_generation_model(model_path)
-            if is_generation_model is None
-            else is_generation_model
-        )
+        self.is_generation_model = is_generation_model
         self.runtime = Runtime(
             model_path=model_path,
             tp_size=tp_size,
...
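After this change, HFRunner and SRTRunner no longer infer torch_dtype or is_generation_model; tests must pass both explicitly. A hedged sketch of the updated call sites (the import path and model name are placeholders):

import torch

# from sglang.test.runners import HFRunner, SRTRunner  # import path assumed

hf_runner = HFRunner(
    "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    torch_dtype=torch.float16,
    is_generation_model=True,
)
srt_runner = SRTRunner(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    torch_dtype=torch.float16,
    is_generation_model=True,
    tp_size=1,
)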