"graphbolt/vscode:/vscode.git/clone" did not exist on "d2eca8557dc969d672bfaf823677dfed6dc7dcb5"
Unverified commit fb9296f0 authored by Ying Sheng, committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)
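The hunks below are almost entirely mechanical reformatting and import reordering. As a rough, hypothetical sketch of the precedence rule named in the title (the function and argument names here are illustrative only and do not appear in this commit), "higher priority for user input of max_prefill_tokens" means an explicitly supplied value wins over the server's derived default:

```python
# Hypothetical illustration, not code from this commit: an explicit,
# user-provided max_prefill_tokens overrides whatever default the server
# would otherwise derive (e.g. from the model's context length).
def resolve_max_prefill_tokens(user_value, derived_default):
    if user_value is not None:
        return user_value  # user input takes priority
    return derived_default
```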

parent 1374334d
@@ -65,7 +65,7 @@ def main(args):
     def get_one_answer(i):
         answer = call_generate(
             prompt=few_shot_examples + questions[i],
-            #prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
+            # prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
             temperature=0,
             max_tokens=256,
             stop="Question",
...
@@ -158,7 +158,9 @@ async def send_request(
     timeout = aiohttp.ClientTimeout(total=3 * 3600)
     async with aiohttp.ClientSession(timeout=timeout) as session:
         while True:
-            async with session.post(api_url, headers=headers, json=pload) as response:
+            async with session.post(
+                api_url, headers=headers, json=pload
+            ) as response:
                 chunks = []
                 async for chunk, _ in response.content.iter_chunks():
                     chunks.append(chunk)
@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
     np.random.seed(args.seed)
     api_url = f"http://{args.host}:{args.port}/generate"
-    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
+    tokenizer = AutoTokenizer.from_pretrained(
+        args.tokenizer, trust_remote_code=args.trust_remote_code
+    )
     if args.dataset:
         input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
     else:
         input_lens = np.random.randint(
-            int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts)
+            int(args.input_len * args.range_ratio),
+            args.input_len + 1,
+            size=args.num_prompts,
+        )
         output_lens = np.random.randint(
-            int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts)
+            int(args.output_len * args.range_ratio),
+            args.output_len + 1,
+            size=args.num_prompts,
+        )
         offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
         input_requests = []
         for i in range(args.num_prompts):
-            prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
             input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
     benchmark_start_time = time.perf_counter()
@@ -287,16 +302,15 @@ if __name__ == "__main__":
     )
     parser.add_argument("--host", type=str, default="localhost")
     parser.add_argument("--port", type=int, default=30000)
-    parser.add_argument(
-        "--dataset", type=str, help="Path to the dataset."
-    )
+    parser.add_argument("--dataset", type=str, help="Path to the dataset.")
     parser.add_argument("--input-len", type=int, default=2048)
     parser.add_argument("--output-len", type=int, default=256)
     parser.add_argument("--range-ratio", type=float, default=1.0)
     parser.add_argument(
-        "--tokenizer", type=str,
+        "--tokenizer",
+        type=str,
         default="NousResearch/Meta-Llama-3-8B",
-        help="Name or path of the tokenizer."
+        help="Name or path of the tokenizer.",
     )
     parser.add_argument(
         "--best-of",
...
@@ -24,10 +24,10 @@ from sglang.api import (
 # SGL Backends
 from sglang.backend.anthropic import Anthropic
+from sglang.backend.litellm import LiteLLM
 from sglang.backend.openai import OpenAI
 from sglang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.backend.vertexai import VertexAI
-from sglang.backend.litellm import LiteLLM
 # Global Configurations
 from sglang.global_config import global_config
...
@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
         self.model_name = model_name
         self.chat_template = chat_template or get_chat_template_by_model_path(
-            model_name)
+            model_name
+        )
         self.client_params = {
             "api_key": api_key,
...
+import dataclasses
 import logging
 import time
 import warnings
-import dataclasses
 from typing import Callable, List, Optional, Union
 import numpy as np
@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
     def get_chat_template(self):
         return self.chat_template
-    def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
-                                num_api_spec_tokens: int, spec_var_name: str):
+    def _prepare_spec_execution(
+        self,
+        sampling_params: SglSamplingParams,
+        num_api_spec_tokens: int,
+        spec_var_name: str,
+    ):
         if "max_tokens" not in self.spec_kwargs:
             self.spec_kwargs["max_tokens"] = num_api_spec_tokens
         else:
-            assert (
-                self.spec_kwargs["max_tokens"] == num_api_spec_tokens
-            )
+            assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
         params = sampling_params.to_openai_kwargs()
         for key, value in params.items():
@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
                     )
                     prompt = s.messages_
                 else:
-                    return self._prepare_spec_execution(sampling_params,
-                        s.num_api_spec_tokens, spec_var_name)
+                    return self._prepare_spec_execution(
+                        sampling_params, s.num_api_spec_tokens, spec_var_name
+                    )
             else:
                 prompt = s.text_
@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
         ret_str = ret.choices[0].text
         ret_token = self.tokenizer.encode(ret_str)[0]
         self.token_usage.prompt_tokens += ret.usage.prompt_tokens
-        self.token_usage.completion_tokens= ret.usage.completion_tokens
+        self.token_usage.completion_tokens = ret.usage.completion_tokens
         # TODO:
         # 1. return logits as the scores
@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
         return decision, scores, None, None
-def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+def openai_completion(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
     return comp
-def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
+def openai_completion_stream(
+    client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
+):
     for attempt in range(retries):
         try:
             if is_chat:
                 if "stop" in kwargs and kwargs["stop"] is None:
                     kwargs.pop("stop")
                 generator = client.chat.completions.create(
-                    messages=prompt, stream=True, stream_options={"include_usage": True},
-                    **kwargs
+                    messages=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
                     yield content or "", {}
             else:
                 generator = client.completions.create(
-                    prompt=prompt, stream=True, stream_options={"include_usage": True},
-                    **kwargs
+                    prompt=prompt,
+                    stream=True,
+                    stream_options={"include_usage": True},
+                    **kwargs,
                 )
                 for ret in generator:
                     if len(ret.choices) == 0:
...
@@ -84,9 +84,7 @@ class SglSamplingParams:
     def to_litellm_kwargs(self):
         if self.regex is not None:
-            warnings.warn(
-                "Regular expression is not supported in the LiteLLM backend."
-            )
+            warnings.warn("Regular expression is not supported in the LiteLLM backend.")
         return {
             "max_tokens": self.max_new_tokens,
             "stop": self.stop or None,
...
 """Launch the inference server for Llava-video model."""
 import argparse
 import multiprocessing as mp
...
@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
 from outlines.caching import cache as disk_cache
 from outlines.caching import disable_cache
 from outlines.fsm.guide import RegexGuide
-from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
+from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
 from outlines.models.transformers import TransformerTokenizer
 from pydantic import BaseModel
...
 """Cache for the compressed finite state machine."""
 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_cache import BaseCache
...
@@ -8,11 +8,12 @@ from collections import defaultdict
 import interegular
 import outlines.caching
 from sglang.srt.constrained import (
     FSMInfo,
     disk_cache,
-    make_deterministic_fsm,
     make_byte_level_fsm,
+    make_deterministic_fsm,
 )
 from sglang.srt.constrained.base_cache import BaseCache
...
 """Conversation templates."""
 # Adapted from
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
...
 """Utilities for Huggingface Transformers."""
+import functools
 import json
 import os
 import warnings
-import functools
-from typing import Optional, Union, AbstractSet, Collection, Literal
+from typing import AbstractSet, Collection, Literal, Optional, Union
 from huggingface_hub import snapshot_download
 from transformers import (
@@ -179,6 +179,7 @@ def get_processor(
 class TiktokenTokenizer:
     def __init__(self, tokenizer_path):
         import tiktoken
         PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
         # Read JSON
@@ -190,7 +191,8 @@ class TiktokenTokenizer:
             bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
         }
         special_tokens = {
-            bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
+            bytes(item["bytes"]).decode(): item["token"]
+            for item in tok_dict["special_tokens"]
         }
         assert tok_dict["word_split"] == "V1"
@@ -202,7 +204,10 @@ class TiktokenTokenizer:
         }
         if "default_allowed_special" in tok_dict:
             default_allowed_special = set(
-                [bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
+                [
+                    bytes(bytes_list).decode()
+                    for bytes_list in tok_dict["default_allowed_special"]
+                ]
             )
         else:
             default_allowed_special = None
@@ -216,14 +221,20 @@ class TiktokenTokenizer:
            self,
            text: str,
            *,
-           allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
+           allowed_special: Union[
+               Literal["all"], AbstractSet[str]
+           ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> list[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
-               self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
+               self,
+               text,
+               allowed_special=allowed_special,
+               disallowed_special=disallowed_special,
            )
        tokenizer.encode = functools.partial(encode_patched, tokenizer)
        # Convert to HF interface
@@ -237,10 +248,14 @@ class TiktokenTokenizer:
     def decode(self, x):
         return self.tokenizer.decode(x)
-    def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
+    def batch_decode(
+        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
+    ):
         if isinstance(batch[0], int):
             batch = [[x] for x in batch]
         return self.tokenizer.decode_batch(batch)
     def convert_ids_to_tokens(self, index):
-        return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
+        return self.tokenizer.decode_single_token_bytes(index).decode(
+            "utf-8", errors="ignore"
+        )
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.utils import is_hip
@@ -109,12 +108,16 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
-                      offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -130,13 +133,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +149,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
     if use_fp8:
@@ -159,15 +159,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 def moe_align_block_size(
-    topk_ids: torch.Tensor, block_size: int,
-    num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -206,32 +205,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
-    topk_weights: torch.Tensor, topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
     sorted_token_ids: torch.Tensor,
     expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
-    mul_routed_weight: bool, top_k: int,
-    config: Dict[str, Any], compute_type: tl.dtype,
-    use_fp8: bool) -> None:
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
     fused_moe_kernel[grid](
         A,
@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)
     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
@@ -352,40 +358,30 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
     if is_hip():
         # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
+        routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
         topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
     else:
         import vllm._moe_C as moe_kernels
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
+        topk_weights = torch.empty(
+            M, topk, dtype=torch.float32, device=hidden_states.device
+        )
+        topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+        token_expert_indicies = torch.empty(
+            M, topk, dtype=torch.int32, device=hidden_states.device
+        )
         moe_kernels.topk_softmax(
             topk_weights,
             topk_ids,
@@ -400,8 +396,7 @@ def fused_moe(
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -415,7 +410,7 @@ def fused_moe(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 1,
            "num_warps": 4,
-           "num_stages": 4
+           "num_stages": 4,
        }
        if M <= E:
@@ -425,25 +420,32 @@ def fused_moe(
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 16,
                "num_warps": 8,
-               "num_stages": 4
+               "num_stages": 4,
            }
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
-    invoke_fused_moe_kernel(hidden_states,
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+    invoke_fused_moe_kernel(
+        hidden_states,
         w1,
         intermediate_cache1,
         a1_scale,
@@ -457,11 +459,13 @@ def fused_moe(
         topk_ids.shape[1],
         config,
         compute_type=compute_type,
-        use_fp8=use_fp8)
+        use_fp8=use_fp8,
+    )
     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
-    invoke_fused_moe_kernel(intermediate_cache2,
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
         w2,
         intermediate_cache3,
         a2_scale,
@@ -475,11 +479,13 @@ def fused_moe(
         1,
         config,
         compute_type=compute_type,
-        use_fp8=use_fp8)
+        use_fp8=use_fp8,
+    )
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
\ No newline at end of file
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
 """Logits processing."""
 import torch
 from torch import nn
 from vllm.distributed import (
...
 """Radix attention."""
-import torch
 import numpy as np
+import torch
 from torch import nn
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
 class RadixAttention(nn.Module):
-    def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
+    def __init__(
+        self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
+    ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
...
@@ -4,7 +4,7 @@ import asyncio
 import logging
 import queue
 import threading
-from typing import List, Callable
+from typing import Callable, List
 import uvloop
 import zmq
@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
            # async sleep for receiving the subsequent request and avoiding cache miss
            if len(out_pyobjs) != 0:
-               has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
+               has_finished = any(
+                   [obj.finished_reason is not None for obj in out_pyobjs]
+               )
                if has_finished:
                    await asyncio.sleep(self.request_dependency_delay)
                await asyncio.sleep(global_config.wait_for_new_request_delay)
...
 """Meta data for requests and batches"""
+import warnings
 from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import List
-import warnings
 import numpy as np
 import torch
+from sglang.srt.constrained import RegexGuide
+from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.managers.controller.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.constrained.jump_forward import JumpForwardMap
-from sglang.srt.constrained import RegexGuide
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
...