"android/vscode:/vscode.git/clone" did not exist on "e1c49fafa7e077c85b4a8cfe1e18ccddeb853959"
Unverified Commit fb9296f0 authored by Ying Sheng, committed by GitHub

Higher priority for user input of max_prefill_tokens & format (#540)

parent 1374334d
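
Note: the hunks below are almost entirely formatting (Black/isort) changes; the behavioral part named in the title is that a user-supplied max_prefill_tokens should take priority over the value derived elsewhere (e.g. from the model config). A minimal, hypothetical sketch of that priority rule is shown here for orientation only; the function and variable names are illustrative and not the actual sglang code:

    def resolve_max_prefill_tokens(user_value, derived_value, fallback=16384):
        """Pick max_prefill_tokens, giving explicit user input the highest priority."""
        # Explicit user input wins over anything derived automatically.
        if user_value is not None:
            return user_value
        # Otherwise fall back to the derived value, then to a conservative default.
        return derived_value if derived_value is not None else fallback

    # Usage: an explicit CLI value overrides the derived default.
    assert resolve_max_prefill_tokens(8192, 32768) == 8192
    assert resolve_max_prefill_tokens(None, 32768) == 32768
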
......@@ -65,7 +65,7 @@ def main(args):
def get_one_answer(i):
answer = call_generate(
prompt=few_shot_examples + questions[i],
#prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
# prompt="System: " + few_shot_examples + "<|separator|>\n\n" + questions[i],
temperature=0,
max_tokens=256,
stop="Question",
......
......@@ -158,7 +158,9 @@ async def send_request(
timeout = aiohttp.ClientTimeout(total=3 * 3600)
async with aiohttp.ClientSession(timeout=timeout) as session:
while True:
async with session.post(api_url, headers=headers, json=pload) as response:
async with session.post(
api_url, headers=headers, json=pload
) as response:
chunks = []
async for chunk, _ in response.content.iter_chunks():
chunks.append(chunk)
......@@ -228,19 +230,32 @@ def main(args: argparse.Namespace):
np.random.seed(args.seed)
api_url = f"http://{args.host}:{args.port}/generate"
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=args.trust_remote_code)
tokenizer = AutoTokenizer.from_pretrained(
args.tokenizer, trust_remote_code=args.trust_remote_code
)
if args.dataset:
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
else:
input_lens = np.random.randint(
int(args.input_len * args.range_ratio), args.input_len + 1, size=args.num_prompts)
int(args.input_len * args.range_ratio),
args.input_len + 1,
size=args.num_prompts,
)
output_lens = np.random.randint(
int(args.output_len * args.range_ratio), args.output_len + 1, size=args.num_prompts)
int(args.output_len * args.range_ratio),
args.output_len + 1,
size=args.num_prompts,
)
offsets = np.random.randint(0, tokenizer.vocab_size, size=args.num_prompts)
input_requests = []
for i in range(args.num_prompts):
prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])])
prompt = tokenizer.decode(
[
(offsets[i] + i + j) % tokenizer.vocab_size
for j in range(input_lens[i])
]
)
input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
benchmark_start_time = time.perf_counter()
......@@ -287,16 +302,15 @@ if __name__ == "__main__":
)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--port", type=int, default=30000)
parser.add_argument(
"--dataset", type=str, help="Path to the dataset."
)
parser.add_argument("--dataset", type=str, help="Path to the dataset.")
parser.add_argument("--input-len", type=int, default=2048)
parser.add_argument("--output-len", type=int, default=256)
parser.add_argument("--range-ratio", type=float, default=1.0)
parser.add_argument(
"--tokenizer", type=str,
"--tokenizer",
type=str,
default="NousResearch/Meta-Llama-3-8B",
help="Name or path of the tokenizer."
help="Name or path of the tokenizer.",
)
parser.add_argument(
"--best-of",
......
......@@ -170,4 +170,4 @@ if __name__ == "__main__":
parser.add_argument("--data_dir", type=str, default="data")
parser.add_argument("--nsub", type=int, default=60)
args = add_common_other_args_and_parse(parser)
main(args)
\ No newline at end of file
main(args)
......@@ -24,10 +24,10 @@ from sglang.api import (
# SGL Backends
from sglang.backend.anthropic import Anthropic
from sglang.backend.litellm import LiteLLM
from sglang.backend.openai import OpenAI
from sglang.backend.runtime_endpoint import RuntimeEndpoint
from sglang.backend.vertexai import VertexAI
from sglang.backend.litellm import LiteLLM
# Global Configurations
from sglang.global_config import global_config
......
......@@ -33,7 +33,8 @@ class LiteLLM(BaseBackend):
self.model_name = model_name
self.chat_template = chat_template or get_chat_template_by_model_path(
model_name)
model_name
)
self.client_params = {
"api_key": api_key,
......
import dataclasses
import logging
import time
import warnings
import dataclasses
from typing import Callable, List, Optional, Union
import numpy as np
......@@ -105,14 +105,16 @@ class OpenAI(BaseBackend):
def get_chat_template(self):
return self.chat_template
def _prepare_spec_execution(self, sampling_params: SglSamplingParams,
num_api_spec_tokens: int, spec_var_name: str):
def _prepare_spec_execution(
self,
sampling_params: SglSamplingParams,
num_api_spec_tokens: int,
spec_var_name: str,
):
if "max_tokens" not in self.spec_kwargs:
self.spec_kwargs["max_tokens"] = num_api_spec_tokens
else:
assert (
self.spec_kwargs["max_tokens"] == num_api_spec_tokens
)
assert self.spec_kwargs["max_tokens"] == num_api_spec_tokens
params = sampling_params.to_openai_kwargs()
for key, value in params.items():
......@@ -151,8 +153,9 @@ class OpenAI(BaseBackend):
)
prompt = s.messages_
else:
return self._prepare_spec_execution(sampling_params,
s.num_api_spec_tokens, spec_var_name)
return self._prepare_spec_execution(
sampling_params, s.num_api_spec_tokens, spec_var_name
)
else:
prompt = s.text_
......@@ -325,7 +328,7 @@ class OpenAI(BaseBackend):
ret_str = ret.choices[0].text
ret_token = self.tokenizer.encode(ret_str)[0]
self.token_usage.prompt_tokens += ret.usage.prompt_tokens
self.token_usage.completion_tokens= ret.usage.completion_tokens
self.token_usage.completion_tokens = ret.usage.completion_tokens
# TODO:
# 1. return logits as the scores
......@@ -355,7 +358,9 @@ class OpenAI(BaseBackend):
return decision, scores, None, None
def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
......@@ -385,15 +390,19 @@ def openai_completion(client, token_usage, is_chat=None, retries=3, prompt=None,
return comp
def openai_completion_stream(client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs):
def openai_completion_stream(
client, token_usage, is_chat=None, retries=3, prompt=None, **kwargs
):
for attempt in range(retries):
try:
if is_chat:
if "stop" in kwargs and kwargs["stop"] is None:
kwargs.pop("stop")
generator = client.chat.completions.create(
messages=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
messages=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
......@@ -405,8 +414,10 @@ def openai_completion_stream(client, token_usage, is_chat=None, retries=3, promp
yield content or "", {}
else:
generator = client.completions.create(
prompt=prompt, stream=True, stream_options={"include_usage": True},
**kwargs
prompt=prompt,
stream=True,
stream_options={"include_usage": True},
**kwargs,
)
for ret in generator:
if len(ret.choices) == 0:
......
......@@ -507,7 +507,7 @@ class StreamExecutor:
)
return
else: # Speculative execution on models with completion interface
else: # Speculative execution on models with completion interface
comp, meta_info = self._spec_gen(sampling_params)
self.text_ += comp
......
......@@ -81,12 +81,10 @@ class SglSamplingParams:
"top_p": self.top_p,
"top_k": self.top_k,
}
def to_litellm_kwargs(self):
if self.regex is not None:
warnings.warn(
"Regular expression is not supported in the LiteLLM backend."
)
warnings.warn("Regular expression is not supported in the LiteLLM backend.")
return {
"max_tokens": self.max_new_tokens,
"stop": self.stop or None,
......
......@@ -10,4 +10,4 @@ if __name__ == "__main__":
args = parser.parse_args()
server_args = ServerArgs.from_cli_args(args)
launch_server(server_args, None)
\ No newline at end of file
launch_server(server_args, None)
"""Launch the inference server for Llava-video model."""
import argparse
import multiprocessing as mp
......
......@@ -4,7 +4,7 @@ from typing import Dict, Optional, Union
from outlines.caching import cache as disk_cache
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm, make_byte_level_fsm
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
from pydantic import BaseModel
......
"""Cache for the compressed finite state machine."""
from sglang.srt.constrained import RegexGuide, TransformerTokenizer
from sglang.srt.constrained.base_cache import BaseCache
......
......@@ -8,11 +8,12 @@ from collections import defaultdict
import interegular
import outlines.caching
from sglang.srt.constrained import (
FSMInfo,
disk_cache,
make_deterministic_fsm,
make_byte_level_fsm,
make_deterministic_fsm,
)
from sglang.srt.constrained.base_cache import BaseCache
......
"""Conversation templates."""
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
......
"""Utilities for Huggingface Transformers."""
import functools
import json
import os
import warnings
import functools
from typing import Optional, Union, AbstractSet, Collection, Literal
from typing import AbstractSet, Collection, Literal, Optional, Union
from huggingface_hub import snapshot_download
from transformers import (
......@@ -179,6 +179,7 @@ def get_processor(
class TiktokenTokenizer:
def __init__(self, tokenizer_path):
import tiktoken
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""
# Read JSON
......@@ -190,7 +191,8 @@ class TiktokenTokenizer:
bytes(item["bytes"]): item["token"] for item in tok_dict["regular_tokens"]
}
special_tokens = {
bytes(item["bytes"]).decode(): item["token"] for item in tok_dict["special_tokens"]
bytes(item["bytes"]).decode(): item["token"]
for item in tok_dict["special_tokens"]
}
assert tok_dict["word_split"] == "V1"
......@@ -202,7 +204,10 @@ class TiktokenTokenizer:
}
if "default_allowed_special" in tok_dict:
default_allowed_special = set(
[bytes(bytes_list).decode() for bytes_list in tok_dict["default_allowed_special"]]
[
bytes(bytes_list).decode()
for bytes_list in tok_dict["default_allowed_special"]
]
)
else:
default_allowed_special = None
......@@ -216,14 +221,20 @@ class TiktokenTokenizer:
self,
text: str,
*,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), # noqa: B006
allowed_special: Union[
Literal["all"], AbstractSet[str]
] = set(), # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
) -> list[int]:
if isinstance(allowed_special, set):
allowed_special |= self._default_allowed_special
return tiktoken.Encoding.encode(
self, text, allowed_special=allowed_special, disallowed_special=disallowed_special
self,
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
tokenizer.encode = functools.partial(encode_patched, tokenizer)
# Convert to HF interface
......@@ -237,10 +248,14 @@ class TiktokenTokenizer:
def decode(self, x):
return self.tokenizer.decode(x)
def batch_decode(self, batch, skip_special_tokens=True, spaces_between_special_tokens=False):
def batch_decode(
self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
):
if isinstance(batch[0], int):
batch = [[x] for x in batch]
return self.tokenizer.decode_batch(batch)
def convert_ids_to_tokens(self, index):
return self.tokenizer.decode_single_token_bytes(index).decode("utf-8", errors="ignore")
\ No newline at end of file
return self.tokenizer.decode_single_token_bytes(index).decode(
"utf-8", errors="ignore"
)
......@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.utils import is_hip
......@@ -109,12 +108,16 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
......@@ -130,13 +133,12 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
......@@ -147,9 +149,7 @@ def fused_moe_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_fp8:
......@@ -159,15 +159,14 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
......@@ -206,32 +205,38 @@ def moe_align_block_size(
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
ops.moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8: bool,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -242,8 +247,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
grid = lambda META: (
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid](
A,
......@@ -281,8 +288,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -297,11 +303,11 @@ def get_moe_configs(E: int, N: int,
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
logger.info("Using configuration from %s for MoE layer.", config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
......@@ -352,40 +358,30 @@ def fused_moe(
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
if is_hip():
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output,
dim=-1,
dtype=torch.float32)
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
import vllm._moe_C as moe_kernels
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
moe_kernels.topk_softmax(
topk_weights,
topk_ids,
......@@ -400,8 +396,7 @@ def fused_moe(
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
......@@ -415,7 +410,7 @@ def fused_moe(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
"num_stages": 4,
}
if M <= E:
......@@ -425,61 +420,72 @@ def fused_moe(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 4
"num_stages": 4,
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
\ No newline at end of file
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
"""Logits processing."""
import torch
from torch import nn
from vllm.distributed import (
......
"""Radix attention."""
import torch
import numpy as np
import torch
from torch import nn
from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
......@@ -10,7 +11,9 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
def __init__(self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1):
def __init__(
self, num_heads, head_dim, scaling, num_kv_heads, layer_id, logit_cap=-1
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
......
......@@ -4,7 +4,7 @@ import asyncio
import logging
import queue
import threading
from typing import List, Callable
from typing import Callable, List
import uvloop
import zmq
......@@ -70,7 +70,9 @@ class DataParallelWorkerThread(threading.Thread):
# async sleep for receiving the subsequent request and avoiding cache miss
if len(out_pyobjs) != 0:
has_finished = any([obj.finished_reason is not None for obj in out_pyobjs])
has_finished = any(
[obj.finished_reason is not None for obj in out_pyobjs]
)
if has_finished:
await asyncio.sleep(self.request_dependency_delay)
await asyncio.sleep(global_config.wait_for_new_request_delay)
......@@ -108,4 +110,4 @@ def start_data_parallel_worker(
step_func=model_tp_client.step,
)
worker_thread.start()
return worker_thread
\ No newline at end of file
return worker_thread
"""Meta data for requests and batches"""
import warnings
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import List
import warnings
import numpy as np
import torch
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.managers.controller.radix_cache import RadixCache
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.constrained import RegexGuide
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
......