Unverified Commit f6af3a65 authored by Lianmin Zheng, committed by GitHub

Cleanup readme, llava examples, usage examples and nccl init (#1194)

parent c9064e6f
......@@ -4,7 +4,7 @@ Usage:
# Installing latest sglang.
# Endpoint Service CLI:
# python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --tokenizer-path lmms-lab/llavanext-qwen-tokenizer --port=30000 --host="127.0.0.1" --tp-size=4
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
python3 http_qwen_llava_test.py
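Once the server is running, it can also be queried directly over HTTP. The snippet below is a minimal sketch against the /generate endpoint on the port used above; the exact payload accepted for image inputs may vary by sglang version, so only a plain-text request is shown.

import requests

# Hedged example: text-only request to the server launched above.
# Adjust the host/port to match the launch command.
response = requests.post(
    "http://127.0.0.1:30000/generate",
    json={
        "text": "Describe the weather in one sentence.",
        "sampling_params": {"max_new_tokens": 32, "temperature": 0},
    },
)
print(response.json()["text"])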
......@@ -16,7 +16,6 @@ import argparse
import asyncio
import copy
import json
import time
import aiohttp
import requests
......
"""
Usage: python3 srt_example_llava.py
"""
from PIL import ImageFile
import sglang as sgl
from sglang.lang.chat_template import get_chat_template
from sglang.srt.utils import load_image
ImageFile.LOAD_TRUNCATED_IMAGES = True # Allow loading of truncated images
@sgl.function
def image_qa(s, image, question):
s += sgl.user(sgl.image(image) + question)
s += sgl.assistant(sgl.gen("answer"))
def single():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(image=pil_image, question="What is this?", max_new_tokens=512)
print(state["answer"], "\n")
def stream():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
state = image_qa.run(
image=pil_image,
question="Please generate short caption for this image.",
max_new_tokens=512,
temperature=0,
stream=True,
)
for out in state.text_iter("answer"):
print(out, end="", flush=True)
print()
def batch():
image_url = "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg"
pil_image, _ = load_image(image_url)
states = image_qa.run_batch(
[
{"image": pil_image, "question": "What is this?"},
{"image": pil_image, "question": "What is this?"},
],
max_new_tokens=512,
)
for s in states:
print(s["answer"], "\n")
if __name__ == "__main__":
import multiprocessing as mp
mp.set_start_method("spawn", force=True)
runtime = sgl.Runtime(
model_path="lmms-lab/llama3-llava-next-8b",
tokenizer_path="lmms-lab/llama3-llava-next-8b-tokenizer",
)
runtime.endpoint.chat_template = get_chat_template("llama-3-instruct")
# runtime = sgl.Runtime(
# model_path="lmms-lab/llava-next-72b",
# tokenizer_path="lmms-lab/llavanext-qwen-tokenizer",
# )
# runtime.endpoint.chat_template = get_chat_template("chatml-llava")
sgl.set_default_backend(runtime)
print(f"chat template: {runtime.endpoint.chat_template.name}")
# Or you can use API models
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
# Run a single request
print("\n========== single ==========\n")
single()
# Stream output
print("\n========== stream ==========\n")
stream()
# Run a batch of requests
print("\n========== batch ==========\n")
batch()
runtime.shutdown()
......@@ -111,7 +111,11 @@ def load_model(server_args, tp_rank):
suppress_other_loggers()
rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None
model_config = ModelConfig(path=server_args.model_path)
model_config = ModelConfig(
server_args.model_path,
server_args.trust_remote_code,
context_length=server_args.context_length,
)
model_runner = ModelRunner(
model_config=model_config,
mem_fraction_static=server_args.mem_fraction_static,
......
from dataclasses import dataclass, field
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
from typing import Callable, Dict, List, Tuple
class ChatTemplateStyle(Enum):
......
"""Launch the inference server for Llava-video model."""
import argparse
from sglang.srt.server import ServerArgs, launch_server
if __name__ == "__main__":
model_overide_args = {}
model_overide_args["mm_spatial_pool_stride"] = 2
model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
model_overide_args["num_frames"] = 16
model_overide_args["model_type"] = "llavavid"
if model_overide_args["num_frames"] == 32:
model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
model_overide_args["max_sequence_length"] = 4096 * 2
model_overide_args["tokenizer_model_max_length"] = 4096 * 2
model_overide_args["model_max_length"] = 4096 * 2
parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args()
if "34b" in args.model_path.lower():
model_overide_args["image_token_index"] = 64002
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, model_overide_args, None)
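The 32-frame branch above doubles the usable context window via linear RoPE scaling. The sketch below restates that relationship with the same numbers; the position-scaling rule is the standard linear-interpolation formula and is included here as an assumption, not something defined in this commit.

# Linear RoPE scaling with factor 2.0 divides position indices by the factor,
# so a sequence of 4096 * 2 = 8192 tokens maps back into the model's original
# 4096-position range.
original_max_positions = 4096
rope_scaling = {"factor": 2.0, "type": "linear"}

max_sequence_length = int(original_max_positions * rope_scaling["factor"])
assert max_sequence_length == 8192

def scaled_position(token_index: int) -> float:
    # position index as seen by the RoPE kernel after scaling
    return token_index / rope_scaling["factor"]

assert scaled_position(max_sequence_length - 1) < original_max_positions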
......@@ -26,7 +26,7 @@ import triton.language as tl
from sglang.srt.managers.schedule_batch import global_server_args_dict
if global_server_args_dict.get("attention_reduce_in_fp32", False):
if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
REDUCE_TORCH_TYPE = torch.float32
else:
......
......@@ -239,7 +239,7 @@ class FusedMoE(torch.nn.Module):
weight_name: str,
shard_id: int,
expert_id: int,
pre_sharded: bool,
use_presharded_weights: bool = False,
):
param_data = param.data
......@@ -273,7 +273,7 @@ class FusedMoE(torch.nn.Module):
else:
tp_rank = get_tensor_model_parallel_rank()
shard_size = self.intermediate_size_per_partition
if pre_sharded:
if use_presharded_weights:
shard = slice(None)
else:
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
......
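To make the renamed use_presharded_weights flag concrete, here is a small standalone sketch (assumed shapes and rank values, not the FusedMoE code itself): a pre-sharded checkpoint already stores only this rank's slice, so the loader keeps the whole tensor instead of slicing it again.

import torch

tp_rank, shard_size, num_ranks = 1, 4, 2
full_weight = torch.arange(shard_size * num_ranks)   # weight as stored in an unsharded checkpoint
presharded_weight = full_weight[tp_rank * shard_size : (tp_rank + 1) * shard_size]

def select_shard(loaded_weight, use_presharded_weights: bool):
    if use_presharded_weights:
        shard = slice(None)   # checkpoint already holds only this rank's shard
    else:
        shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
    return loaded_weight[shard]

assert torch.equal(select_shard(full_weight, False), select_shard(presharded_weight, True))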
......@@ -180,7 +180,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
last_logits.div_(self.config.final_logit_softcapping)
last_logits = torch.tanh(last_logits)
torch.tanh(last_logits, out=last_logits)
last_logits.mul_(self.config.final_logit_softcapping)
# Return only last_logits if logprob is not requested
......@@ -241,7 +241,7 @@ class LogitsProcessor(nn.Module):
if hasattr(self.config, "final_logit_softcapping"):
all_logits.div_(self.config.final_logit_softcapping)
all_logits = torch.tanh(all_logits)
torch.tanh(all_logits, out=all_logits)
all_logits.mul_(self.config.final_logit_softcapping)
all_logprobs = all_logits
......
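The logits change above swaps an out-of-place tanh for an in-place one. A short sketch of the full soft-capping sequence as it now runs (the cap value is illustrative; the config field name final_logit_softcapping comes from the diff):

import torch

final_logit_softcapping = 30.0      # assumed value for illustration
logits = torch.randn(4, 32000)

logits.div_(final_logit_softcapping)
torch.tanh(logits, out=logits)      # in-place: avoids allocating a second logits tensor
logits.mul_(final_logit_softcapping)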
......@@ -35,7 +35,7 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
"triton_attention_reduce_in_fp32": False,
"enable_mla": False,
}
......
......@@ -606,6 +606,9 @@ class TokenizerManager:
return background_tasks
def create_handle_loop(self):
if not self.to_create_loop:
return
self.to_create_loop = False
loop = asyncio.get_event_loop()
loop.create_task(self.handle_loop())
......
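The early return added above makes loop creation idempotent. A self-contained sketch of the same pattern (class and method bodies are placeholders, not the TokenizerManager implementation):

import asyncio

class Manager:
    def __init__(self):
        self.to_create_loop = True
        self.tasks = []

    async def handle_loop(self):
        await asyncio.sleep(0)      # stand-in for the real receive loop

    def create_handle_loop(self):
        if not self.to_create_loop:
            return                   # already scheduled; calling again is a no-op
        self.to_create_loop = False
        self.tasks.append(asyncio.get_event_loop().create_task(self.handle_loop()))

async def main():
    m = Manager()
    m.create_handle_loop()
    m.create_handle_loop()           # second call does nothing
    assert len(m.tasks) == 1

asyncio.run(main())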
......@@ -20,7 +20,6 @@ import importlib
import importlib.resources
import logging
import pkgutil
import warnings
from functools import lru_cache
from typing import Optional, Type
......@@ -91,23 +90,35 @@ class ModelRunner:
{
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
min_per_gpu_memory = self.init_torch_distributed()
self.load_model()
self.init_memory_pool(
min_per_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
self.init_cuda_graphs()
def init_torch_distributed(self):
# Init torch distributed
torch.cuda.set_device(self.gpu_id)
logger.info(f"[gpu={self.gpu_id}] Init nccl begin.")
if not server_args.enable_p2p_check:
if not self.server_args.enable_p2p_check:
monkey_patch_vllm_p2p_access_check(self.gpu_id)
if server_args.nccl_init_addr:
nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
if self.server_args.nccl_init_addr:
nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
else:
nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
set_custom_all_reduce(not server_args.disable_custom_all_reduce)
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
init_distributed_environment(
backend="nccl",
world_size=self.tp_size,
......@@ -116,32 +127,28 @@ class ModelRunner:
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
min_per_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
self.tp_group = get_tp_group()
# Currently, there is a bug with multi-node tensor parallelism + padded cuda graph,
# so we disable padding in cuda graph.
if not all(in_the_same_node_as(self.tp_group.cpu_group, source_rank=0)):
self.server_args.disable_cuda_graph_padding = True
logger.info(
"Setting disable_cuda_graph_padding to True because of multi-node tensor parallelism."
)
# Check memory for tensor parallelism
if self.tp_size > 1:
total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if total_local_gpu_memory < total_gpu_memory * 0.9:
local_gpu_memory = get_available_gpu_memory(self.gpu_id)
if min_per_gpu_memory < local_gpu_memory * 0.9:
raise ValueError(
"The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
)
# Load the model and create memory pool
self.load_model()
self.init_memory_pool(
total_gpu_memory,
server_args.max_num_reqs,
server_args.max_total_tokens,
)
self.init_cublas()
self.init_flashinfer()
if self.is_generation:
# FIXME Currently, cuda graph only capture decode steps, which only exists in causal models
# Capture cuda graphs
self.init_cuda_graphs()
return min_per_gpu_memory
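For orientation, here is a rough standalone sketch of the NCCL initialization flow that init_torch_distributed now encapsulates: pick a TCP init method, set up the process group, then report per-GPU free memory so ranks can be checked for imbalance. It uses plain torch.distributed rather than the vLLM helpers called in the real code, so treat it as an approximation.

from typing import Optional

import torch
import torch.distributed as dist

def init_nccl(gpu_id: int, tp_rank: int, tp_size: int,
              nccl_init_addr: Optional[str], nccl_port: int) -> float:
    torch.cuda.set_device(gpu_id)
    if nccl_init_addr:                                   # e.g. "10.0.0.1:29500" for multi-node
        init_method = f"tcp://{nccl_init_addr}"
    else:
        init_method = f"tcp://127.0.0.1:{nccl_port}"     # single-node default
    dist.init_process_group(
        backend="nccl",
        init_method=init_method,
        rank=tp_rank,
        world_size=tp_size,
    )
    free_bytes, _ = torch.cuda.mem_get_info(gpu_id)
    return free_bytes / (1 << 30)                        # available memory in GB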
def load_model(self):
logger.info(
......@@ -150,7 +157,7 @@ class ModelRunner:
)
if torch.cuda.get_device_capability()[0] < 8:
logger.info(
"Compute capability below sm80 use float16 due to lack of bfloat16 support."
"Compute capability below sm80. Use float16 due to lack of bfloat16 support."
)
self.server_args.dtype = "float16"
......@@ -168,8 +175,9 @@ class ModelRunner:
skip_tokenizer_init=True,
)
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
# Drop this after Sept, 2024.
if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
# A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
self.model_config.hf_config.num_key_value_heads = 8
self.vllm_model_config.hf_config.num_key_value_heads = 8
monkey_patch_vllm_qvk_linear_loader()
......@@ -191,8 +199,8 @@ class ModelRunner:
cache_config=None,
)
self.sliding_window_size = (
self.model.get_window_size()
if hasattr(self.model, "get_window_size")
self.model.get_attention_sliding_window_size()
if hasattr(self.model, "get_attention_sliding_window_size")
else None
)
self.is_generation = is_generation_model(
......@@ -206,7 +214,8 @@ class ModelRunner:
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
)
def update_weights(self, model_path, load_format):
def update_weights(self, model_path: str, load_format: str):
"""Update weights in-place."""
from vllm.model_executor.model_loader.loader import (
DefaultModelLoader,
device_loading_context,
......@@ -222,6 +231,7 @@ class ModelRunner:
target_device = torch.device(self.device_config.device)
try:
# TODO: Use a better method to check this
vllm_model_config = VllmModelConfig(
model=model_path,
quantization=self.server_args.quantization,
......@@ -291,7 +301,7 @@ class ModelRunner:
logger.info(f"[gpu={self.gpu_id}] Update weights end.")
return True, "Succeeded to update model weights"
def profile_max_num_token(self, total_gpu_memory):
def profile_max_num_token(self, total_gpu_memory: int):
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
......@@ -319,7 +329,10 @@ class ModelRunner:
return max_num_token
def init_memory_pool(
self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
self,
total_gpu_memory: int,
max_num_reqs: int = None,
max_total_tokens: int = None,
):
self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
if max_total_tokens is not None:
......@@ -388,6 +401,7 @@ class ModelRunner:
return c
def init_flashinfer(self):
"""Init flashinfer attention kernel wrappers."""
if self.server_args.disable_flashinfer:
assert (
self.sliding_window_size is None
......@@ -448,6 +462,11 @@ class ModelRunner:
)
def init_cuda_graphs(self):
"""Capture cuda graphs."""
if not self.is_generation:
# TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
return
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
......@@ -457,7 +476,12 @@ class ModelRunner:
logger.info(
f"[gpu={self.gpu_id}] Capture cuda graph begin. This can take up to several minutes."
)
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
if self.server_args.disable_cuda_graph_padding:
batch_size_list = list(range(1, 32)) + [64, 128]
else:
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.cuda_graph_runner = CudaGraphRunner(
self,
max_batch_size_to_capture=max(batch_size_list),
......
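The batch-size lists above decide which sizes get captured as CUDA graphs. A small sketch of the selection logic, with a hedged note on the rationale: when padding is disabled, a runtime batch presumably has to match a captured size exactly, so small sizes are enumerated densely; with padding enabled, batches can be padded up to the nearest captured size, so a sparser list suffices.

disable_cuda_graph_padding = False   # assumed flag value for illustration

if disable_cuda_graph_padding:
    batch_size_list = list(range(1, 32)) + [64, 128]             # dense small sizes, a few large ones
else:
    batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]  # sparse: 1, 2, 4, 8, 16, ..., 160

max_batch_size_to_capture = max(batch_size_list)                 # 160 in the padded case
assert max_batch_size_to_capture == 160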
......@@ -46,7 +46,7 @@ from sglang.srt.model_executor.forward_batch_info import InputMetadata
# Aligned with HF's implementation, using sliding window inclusive with the last token
# SGLang assumes exclusive
def get_window_size(config):
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
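A tiny illustration of the off-by-one convention described in the comment above: HF treats the sliding window as inclusive of the last token while SGLang treats it as exclusive, so the config value is reduced by one before being handed to the attention layer. The config class below is a stand-in, not a real HF config.

class FakeGemma2Config:
    sliding_window = 4096             # HF-style (inclusive) window size

def get_attention_sliding_window_size(config):
    return config.sliding_window - 1  # SGLang-style (exclusive) window size

assert get_attention_sliding_window_size(FakeGemma2Config()) == 4095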
......@@ -213,7 +213,11 @@ class Gemma2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
sliding_window_size=get_window_size(config) if use_sliding_window else None,
sliding_window_size=(
get_attention_sliding_window_size(config)
if use_sliding_window
else None
),
logit_cap=self.config.attn_logit_softcapping,
)
......@@ -406,8 +410,8 @@ class Gemma2ForCausalLM(nn.Module):
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
def get_window_size(self):
return get_window_size(self.config)
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -295,12 +295,14 @@ class Grok1ModelForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Grok1Model(config, quant_config=quant_config)
# self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.lm_head = ReplicatedLinear(config.hidden_size, config.vocab_size)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
self.use_presharded_weights = True
warnings.filterwarnings("ignore", category=FutureWarning)
def forward(
......@@ -356,6 +358,13 @@ class Grok1ModelForCausalLM(nn.Module):
continue
name = name.replace(weight_name, param_name)
if self.use_presharded_weights:
extra_kwargs = {
"use_presharded_weights": self.use_presharded_weights
}
else:
extra_kwargs = {}
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
......@@ -364,7 +373,7 @@ class Grok1ModelForCausalLM(nn.Module):
weight_name,
shard_id=shard_id,
expert_id=expert_id,
pre_sharded=get_tensor_model_parallel_world_size() > 1,
**extra_kwargs,
)
break
else:
......
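The extra_kwargs indirection above keeps the call backward compatible: the new keyword is forwarded only when pre-sharded weights are in use, so weight loaders that never learned the argument still work. A minimal sketch of the pattern (the loader below is a placeholder, not FusedMoE's):

def weight_loader(param, loaded_weight, weight_name, shard_id, expert_id,
                  use_presharded_weights=False):
    return use_presharded_weights

use_presharded_weights = True        # assumed; set in the Grok-1 constructor above
extra_kwargs = (
    {"use_presharded_weights": use_presharded_weights} if use_presharded_weights else {}
)
assert weight_loader(None, None, "w", shard_id=0, expert_id=0, **extra_kwargs) is True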
......@@ -81,13 +81,12 @@ class ServerArgs:
disable_cuda_graph: bool = False
disable_cuda_graph_padding: bool = False
disable_disk_cache: bool = False
disable_custom_all_reduce: bool = False
enable_mixed_chunk: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
attention_reduce_in_fp32: bool = False
efficient_weight_load: bool = False
disable_custom_all_reduce: bool = False
triton_attention_reduce_in_fp32: bool = False
# Distributed args
nccl_init_addr: Optional[str] = None
......@@ -404,6 +403,12 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
parser.add_argument(
"--enable-mixed-chunk",
action="store_true",
......@@ -425,7 +430,7 @@ class ServerArgs:
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
"--triton-attention-reduce-in-fp32",
action="store_true",
help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16."
"This only affects Triton attention kernels.",
......@@ -435,12 +440,6 @@ class ServerArgs:
action="store_true",
help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).",
)
parser.add_argument(
"--disable-custom-all-reduce",
action="store_true",
default=False,
help="Disable the custom all-reduce kernel and fall back to NCCL.",
)
@classmethod
def from_cli_args(cls, args: argparse.Namespace):
......
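To see the renamed and relocated flags end to end, here is a hedged sketch of parsing them into ServerArgs (the import path follows the example earlier in this diff; the model path is illustrative):

import argparse
from sglang.srt.server import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
args = parser.parse_args([
    "--model-path", "meta-llama/Meta-Llama-3-8B-Instruct",
    "--triton-attention-reduce-in-fp32",
    "--disable-custom-all-reduce",
])
server_args = ServerArgs.from_cli_args(args)
assert server_args.triton_attention_reduce_in_fp32
assert server_args.disable_custom_all_reduce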
......@@ -347,7 +347,7 @@ def suppress_other_loggers():
logging.WARN
)
logging.getLogger("vllm.selector").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.WARN)
logging.getLogger("vllm.utils").setLevel(logging.ERROR)
def assert_pkg_version(pkg: str, min_version: str, message: str):
......@@ -451,10 +451,6 @@ def monkey_patch_vllm_dummy_weight_loader():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
if hasattr(module, "process_weights_after_loading"):
module.process_weights_after_loading()
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
......
......@@ -24,7 +24,6 @@ import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.server import Runtime
from sglang.srt.utils import is_generation_model
DEFAULT_PROMPTS = [
# the output of gemma-2-2b from SRT is unstable on the commented prompt
......@@ -63,8 +62,8 @@ class HFRunner:
def __init__(
self,
model_path,
torch_dtype=torch.float16,
is_generation_model=None,
torch_dtype,
is_generation_model,
):
self.in_queue = multiprocessing.Queue()
self.out_queue = multiprocessing.Queue()
......@@ -90,11 +89,8 @@ class HFRunner:
trust_remote_code=True,
)
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
if self.is_generation_model:
self.model = AutoModelForCausalLM.from_pretrained(
model_path,
......@@ -176,16 +172,12 @@ class SRTRunner:
def __init__(
self,
model_path,
torch_dtype,
is_generation_model,
tp_size=1,
torch_dtype=torch.float16,
is_generation_model=None,
port=5157,
):
self.is_generation_model = (
is_generation_model(model_path)
if is_generation_model is None
else is_generation_model
)
self.is_generation_model = is_generation_model
self.runtime = Runtime(
model_path=model_path,
tp_size=tp_size,
......
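Since torch_dtype and is_generation_model no longer have defaults, callers must now pass them explicitly. A usage sketch under assumptions (the module path and model name below are illustrative, not taken from this diff):

import torch
from sglang.test.runners import HFRunner, SRTRunner   # assumed import path

model_path = "meta-llama/Meta-Llama-3-8B-Instruct"     # illustrative model

hf = HFRunner(model_path, torch_dtype=torch.float16, is_generation_model=True)
srt = SRTRunner(
    model_path,
    torch_dtype=torch.float16,
    is_generation_model=True,
    tp_size=1,
)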