Commit a130cf33 authored by zhuwenwen

Merge tag 'v0.3.3' into vllm-v0.3.2-dtk23.10 and add gfx

parents a2d181be 82091b86
...@@ -87,7 +87,7 @@ def add_lora(y: torch.Tensor, ...@@ -87,7 +87,7 @@ def add_lora(y: torch.Tensor,
r = wb_t_all.size(-1) r = wb_t_all.size(-1)
if buffer is None: if buffer is None:
# We set the buffer to be float32 by default to avoid # We set the buffer to be float32 by default to avoid
# numerical innacuracies that would otherwise happen # numerical inaccuracies that would otherwise happen
# due to downcasting. # due to downcasting.
buffer = torch.zeros((x.size(0), r), buffer = torch.zeros((x.size(0), r),
dtype=torch.float32, dtype=torch.float32,
......
from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.utils import set_random_seed from vllm.model_executor.utils import set_random_seed, get_model
__all__ = [ __all__ = [
"InputMetadata", "InputMetadata",
......
import asyncio
import concurrent.futures
from copy import copy
from enum import Enum
from functools import lru_cache
from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Union, Tuple
from pydantic import BaseModel
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
class GuidedDecodingMode(Enum):
JSON = "json"
REGEX = "regex"
CHOICE = "choice"
global_thread_pool = None # used for generating logits processor fsm
async def get_guided_decoding_logits_processor(
request: Union[CompletionRequest, ChatCompletionRequest],
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
"""
Given an OpenAI-compatible request, check for guided decoding parameters
and get the necessary logits processor for the given guide.
We cache logit processors by (guide, tokenizer), and on cache hit
we make a shallow copy to reuse the same underlying FSM.
"""
global global_thread_pool
guide, mode = _get_guide_and_mode(request)
if not guide:
return None
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2)
loop = asyncio.get_running_loop()
result = await loop.run_in_executor(global_thread_pool,
_get_cached_logits_processor, guide,
tokenizer, mode)
logits_processor = copy(result)
# reset logits processor's internal state
logits_processor.init_state()
return logits_processor
def _get_guide_and_mode(
request: Union[CompletionRequest, ChatCompletionRequest]
) -> Tuple[str, GuidedDecodingMode]:
if request.guided_json:
if not isinstance(request.guided_json, (str, dict, BaseModel)):
raise TypeError("JSON schema must be str, dict, or BaseModel")
json = request.guided_json
if isinstance(json, dict):
# turn dict into hashable string
json = json_dumps(json, sort_keys=True)
elif isinstance(json, BaseModel):
# use pydantic signature so that different model classes
# with the same fields will get hashed the same
json = str(json.__signature__)
return json, GuidedDecodingMode.JSON
elif request.guided_regex:
if not isinstance(request.guided_regex, str):
raise TypeError("Regex must be string")
return request.guided_regex, GuidedDecodingMode.REGEX
elif request.guided_choice:
if not isinstance(request.guided_choice, list):
raise TypeError("Choices must be a list")
# choice just uses regex
choices = [
regex_escape(str(choice)) for choice in request.guided_choice
]
choices_regex = "(" + "|".join(choices) + ")"
return choices_regex, GuidedDecodingMode.CHOICE
else:
return None, None
@lru_cache(maxsize=32)
def _get_cached_logits_processor(guide: str, tokenizer,
mode: GuidedDecodingMode):
if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer)
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
return RegexLogitsProcessor(guide, tokenizer)
else:
raise ValueError(f"Unknown guided decoding mode {mode}")
# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional
import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema
class RegexLogitsProcessor:
def __init__(self, regex_string: str, tokenizer):
"""Compile the FSM that drives the regex-structured generation.
Parameters
----------
regex_string
A string that represents a regular expression
tokenizer
The model's tokenizer
"""
tokenizer = self.adapt_tokenizer(tokenizer)
fsm = RegexFSM(regex_string, tokenizer)
self.fsm = fsm
def init_state(self):
"""Initialize the FSM states."""
self.fsm_state: DefaultDict[int, int] = defaultdict(int)
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
seq_id = hash(tuple(input_ids))
if len(input_ids) == 0:
self.init_state()
else:
last_token = input_ids[-1]
last_seq_id = hash(tuple(input_ids[:-1]))
self.fsm_state[seq_id] = self.fsm.next_state(
self.fsm_state[last_seq_id], last_token)
allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])
mask = torch.full((scores.shape[-1], ),
-math.inf,
device=scores.device)
mask[allowed_tokens] = 0
scores.add_(mask)
return scores
def adapt_tokenizer(self, tokenizer):
"""Adapt vLLM's tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of
`transformers`. In addition we need to handle the missing spaces to
Llama's tokenizer to be able to compile FSMs for this model.
"""
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
def convert_token_to_string(token: str) -> str:
from transformers.file_utils import SPIECE_UNDERLINE
string = tokenizer.convert_tokens_to_string([token])
# A hack to handle the missing spaces in HF's Llama tokenizers
if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
return " " + string
return string
tokenizer.convert_token_to_string = convert_token_to_string
return tokenizer
class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self,
schema: Union[str, Dict, BaseModel],
tokenizer,
whitespace_pattern: Optional[str] = None):
"""Compile the FSM that drives the JSON-guided generation.
Parameters
----------
schema
A JSON schema that encodes the structure we want the model to generate
tokenizer
The model's tokenizer
whitespace_pattern
Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
"""
if isinstance(schema, type(BaseModel)):
schema_str = json.dumps(schema.model_json_schema())
elif isinstance(schema, Dict):
schema_str = json.dumps(schema)
elif isinstance(schema, str):
schema_str = schema
else:
raise ValueError(
f"Cannot parse schema {schema}. The schema must be either " +
"a Pydantic object, a dictionary or a string that contains the JSON "
+ "Schema specification")
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
super().__init__(regex_string, tokenizer)
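A small construction example for the processors above, assuming they are importable from this module; the Pydantic schema, model name, and tokenizer are placeholders rather than anything this commit prescribes:
from pydantic import BaseModel
from transformers import AutoTokenizer

class Answer(BaseModel):    # illustrative schema only
    verdict: str
    confidence: float

tokenizer = AutoTokenizer.from_pretrained("gpt2")   # placeholder tokenizer
json_proc = JSONLogitsProcessor(Answer, tokenizer)  # accepts str, dict, or BaseModel
regex_proc = RegexLogitsProcessor(r"(yes|no)", tokenizer)

# Reset per-sequence FSM state before the first decoding step, as the
# caller in guided_decoding.py does after copying a cached processor.
json_proc.init_state()
regex_proc.init_state()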
...@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module): ...@@ -37,6 +37,29 @@ class SiluAndMul(nn.Module):
return out return out
class GeluAndMul(nn.Module):
"""An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
Shapes:
x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
return: (batch_size, seq_len, d) or (num_tokens, d)
"""
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.gelu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
ops.gelu_and_mul(out, x)
return out
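A hedged shape check for GeluAndMul; the sizes are arbitrary and only the PyTorch-native _forward path is exercised, so no compiled vLLM kernel is needed (the import path is assumed from this module's location):
import torch
from vllm.model_executor.layers.activation import GeluAndMul

layer = GeluAndMul()
x = torch.randn(4, 16, 2 * 11)   # (batch_size, seq_len, 2 * d) with d = 11
out = layer._forward(x)          # reference path equivalent to forward()
assert out.shape == (4, 16, 11)  # GELU(x[..., :d]) * x[..., d:]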
class NewGELU(nn.Module): class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor: def _forward(self, x: torch.Tensor) -> torch.Tensor:
......
...@@ -137,25 +137,27 @@ class PagedAttention(nn.Module): ...@@ -137,25 +137,27 @@ class PagedAttention(nn.Module):
) )
if input_metadata.is_prompt: if input_metadata.is_prompt:
# Prompt run.
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv, query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :, None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# normal attention # normal attention
if (key_cache is None or value_cache is None if (key_cache is None or value_cache is None
or input_metadata.block_tables.numel() == 0): or input_metadata.block_tables.numel() == 0):
if self.num_kv_heads != self.num_heads:
# As of Nov 2023, xformers only supports MHA. For MQA/GQA,
# project the key and value tensors to the desired number of
# heads.
# TODO(woosuk): Use MQA/GQA kernels for higher performance.
query = query.view(query.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
query.shape[-1])
key = key[:, :,
None, :].expand(key.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
key.shape[-1])
value = value[:, :,
None, :].expand(value.shape[0],
self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at # Set attention bias if not provided. This typically happens at
# the very attention layer of every iteration. # the very attention layer of every iteration.
# FIXME(woosuk): This is a hack. # FIXME(woosuk): This is a hack.
......
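A hedged illustration of the MQA/GQA expansion used above: with 8 query heads and 2 KV heads each KV head serves 4 query heads, and expand only creates a broadcast view, so no extra KV memory is allocated (all sizes are illustrative):
import torch

num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads  # 4

key = torch.randn(num_tokens, num_kv_heads, head_size)
key_expanded = key[:, :, None, :].expand(num_tokens, num_kv_heads,
                                         num_queries_per_kv, head_size)
assert key_expanded.shape == (num_tokens, num_kv_heads,
                              num_queries_per_kv, head_size)
assert key_expanded.data_ptr() == key.data_ptr()  # a view, not a copy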
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe
__all__ = [
"fused_moe",
]
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 7},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"64": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 6},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 4}
}
{
"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"2": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"4": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 4},
"16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"24": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4},
"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"80": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"96": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"128": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 4},
"200": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"208": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 4},
"216": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
"224": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"256": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 4},
"512": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1024": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"2048": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"3072": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4},
"4096": {"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 64, "num_warps": 8, "num_stages": 4}
}
This directory contains tuned configurations for different settings of the fused_moe kernel.
For different settings of
- E (number of experts)
- N (intermediate size)
- device_name (torch.cuda.get_device_name())
the JSON file contains a mapping from M (batch size) to the chosen configuration.
The example configurations provided are for the Mixtral model for TP2 on H100
and TP4 on A100. Mixtral has intermediate size N = 14336, i.e. for TP2 we have
N = 7168 and for TP4 we have N = 3584.
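As a hedged illustration of the lookup described above (the file-name pattern mirrors the loader added in fused_moe.py; E, N, M and the device name are example values only):
import json, os, torch

E, N = 8, 7168   # experts, intermediate size per TP rank (TP2 Mixtral example)
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file = f"E={E},N={N},device_name={device_name}.json"

if os.path.exists(config_file):
    with open(config_file) as f:
        configs = {int(k): v for k, v in json.load(f).items()}
    M = 48  # current batch size
    # Pick the closest tuned batch size in the grid, as fused_moe() does.
    config = configs[min(configs, key=lambda m: abs(m - M))]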
"""Fused MoE kernel.""" """Fused MoE kernel."""
import functools
import json
import os
from typing import Any, Dict, Optional, Tuple
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from vllm._C import ops from vllm._C import ops
from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip
logger = init_logger(__name__)
@triton.jit @triton.jit
def fused_moe_kernel( def fused_moe_kernel(
...@@ -129,7 +137,7 @@ def fused_moe_kernel( ...@@ -129,7 +137,7 @@ def fused_moe_kernel(
def moe_align_block_size( def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int, topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> (torch.Tensor, torch.Tensor, torch.Tensor): num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
""" """
Aligns the token distribution across experts to be compatible with block size for matrix multiplication. Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
...@@ -177,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -177,7 +185,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int, config: dict): mul_routed_weight: bool, top_k: int,
config: Dict[str, Any]) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -210,6 +219,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, ...@@ -210,6 +219,34 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
) )
@functools.lru_cache
def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
The return value will be a dictionary that maps an irregular grid of batch sizes
to configurations of the fused_moe kernel. To evaluate the kernel on a given batch
size bs, the closest batch size in the grid should be picked and the associated
configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs directory
device_name = torch.cuda.get_device_name().replace(" ", "_")
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs",
f"E={E},N={N},device_name={device_name}.json")
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info(
f"Using configuration from {config_file_path} for MoE layer.")
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
# If no optimized configuration is available, we will use the default configuration
return None
def fused_moe( def fused_moe(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
...@@ -218,6 +255,7 @@ def fused_moe( ...@@ -218,6 +255,7 @@ def fused_moe(
topk: int, topk: int,
renormalize: bool, renormalize: bool,
inplace: bool = False, inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
...@@ -230,6 +268,7 @@ def fused_moe( ...@@ -230,6 +268,7 @@ def fused_moe(
- topk (int): The number of top-k experts to select. - topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1. - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, perform the operation in-place. Defaults to False. - inplace (bool): If True, perform the operation in-place. Defaults to False.
- override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration.
Returns: Returns:
- torch.Tensor: The output tensor after applying the MoE layer. - torch.Tensor: The output tensor after applying the MoE layer.
...@@ -279,20 +318,31 @@ def fused_moe( ...@@ -279,20 +318,31 @@ def fused_moe(
if renormalize: if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
config = { if override_config:
'BLOCK_SIZE_M': 64, config = override_config
'BLOCK_SIZE_N': 64, else:
'BLOCK_SIZE_K': 32, # First try to load optimal config from the file
'GROUP_SIZE_M': 8 configs = get_moe_configs(E, w2.shape[2])
}
if configs:
if topk_ids.numel() <= w1.shape[0]: # If an optimal configuration map has been found, look up the optimal config
config = { config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
'BLOCK_SIZE_M': 16, else:
'BLOCK_SIZE_N': 32, # Else use the default config
'BLOCK_SIZE_K': 64, config = {
'GROUP_SIZE_M': 1 'BLOCK_SIZE_M': 64,
} 'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
}
if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
}
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device, device=hidden_states.device,
......
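A hedged call sketch for fused_moe with the new override_config argument. The weight layouts (w1 stacking the gate/up projections, w2 holding the down projection) and the gating_output argument are assumptions based on this version of the kernel, and the sizes are arbitrary:
import torch
from vllm.model_executor.layers.fused_moe import fused_moe

M, K, N, E, topk = 32, 4096, 7168, 8, 2   # tokens, hidden, intermediate, experts, top-k
hidden_states = torch.randn(M, K, device="cuda", dtype=torch.float16)
w1 = torch.randn(E, 2 * N, K, device="cuda", dtype=torch.float16)  # gate+up
w2 = torch.randn(E, K, N, device="cuda", dtype=torch.float16)      # down
gating_output = torch.randn(M, E, device="cuda", dtype=torch.float16)

out = fused_moe(hidden_states, w1, w2, gating_output, topk,
                renormalize=True,
                override_config={"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64,
                                 "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8})
assert out.shape == (M, K)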
...@@ -17,6 +17,14 @@ from vllm.logger import init_logger ...@@ -17,6 +17,14 @@ from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
def adjust_marlin_shard(param, shard_size, shard_offset):
marlin_tile_size = getattr(param, "marlin_tile_size", None)
if marlin_tile_size is None:
return shard_size, shard_offset
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
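A small worked example of the helper above: Marlin-packed parameters carry a marlin_tile_size attribute (16 in MarlinLinearMethod below), and only those have their shard size and offset scaled; the stand-in parameter class is purely illustrative:
class _ParamWithTile:        # stand-in for a Parameter tagged by set_weight_attrs
    marlin_tile_size = 16

size, offset = adjust_marlin_shard(_ParamWithTile(), 256, 512)
assert (size, offset) == (256 * 16, 512 * 16)

size, offset = adjust_marlin_shard(object(), 256, 512)  # no tiling attribute
assert (size, offset) == (256, 512)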
class LinearMethodBase(ABC): class LinearMethodBase(ABC):
"""Base class for different (maybe quantized) linear methods.""" """Base class for different (maybe quantized) linear methods."""
...@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -276,6 +284,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow( loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size) output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id) self.weight_loader(param, loaded_weight_shard, shard_id)
...@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -293,6 +306,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
start_idx = tp_rank * shard_size start_idx = tp_rank * shard_size
...@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -372,6 +390,7 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id: Optional[str] = None): loaded_shard_id: Optional[str] = None):
param_data = param.data param_data = param.data
output_dim = getattr(param, "output_dim", None) output_dim = getattr(param, "output_dim", None)
if loaded_shard_id is None: if loaded_shard_id is None:
# Loaded weight is already packed. # Loaded weight is already packed.
if output_dim is None: if output_dim is None:
...@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -393,6 +412,11 @@ class QKVParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
loaded_weight_shard = loaded_weight.narrow( loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size) output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id) self.weight_loader(param, loaded_weight_shard, shard_id)
...@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -417,6 +441,11 @@ class QKVParallelLinear(ColumnParallelLinear):
if packed_dim == output_dim: if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor shard_offset = shard_offset // param.pack_factor
# If marlin, we need to adjust the offset and size to account for the tiling.
shard_size, shard_offset = adjust_marlin_shard(
param, shard_size, shard_offset)
param_data = param_data.narrow(output_dim, shard_offset, param_data = param_data.narrow(output_dim, shard_offset,
shard_size) shard_size)
if loaded_shard_id == "q": if loaded_shard_id == "q":
......
...@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf ...@@ -4,11 +4,13 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConf
from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
from vllm.model_executor.layers.quantization.marlin import MarlinConfig
_QUANTIZATION_CONFIG_REGISTRY = { _QUANTIZATION_CONFIG_REGISTRY = {
"awq": AWQConfig, "awq": AWQConfig,
"gptq": GPTQConfig, "gptq": GPTQConfig,
"squeezellm": SqueezeLLMConfig, "squeezellm": SqueezeLLMConfig,
"marlin": MarlinConfig,
} }
......
import enum import enum
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fractions import Fraction
import torch import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
...@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig): ...@@ -27,11 +28,10 @@ class GPTQConfig(QuantizationConfig):
self.weight_bits = weight_bits self.weight_bits = weight_bits
self.group_size = group_size self.group_size = group_size
self.desc_act = desc_act self.desc_act = desc_act
self.pack_factor = 32 // self.weight_bits self.pack_factor = Fraction(32, self.weight_bits)
# exllama kernel v1 only supports 4 bit if self.weight_bits not in [2, 3, 4, 8]:
if self.weight_bits != 4:
raise ValueError( raise ValueError(
"Currently, only 4-bit weight quantization is supported for " "Currently, only 2/3/4/8-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.") f"GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str: def __repr__(self) -> str:
...@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -101,7 +101,7 @@ class GPTQLinearMethod(LinearMethodBase):
"The input size is not aligned with the quantized " "The input size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
"tensor parallel size.") "tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0: if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
raise ValueError( raise ValueError(
"The output size is not aligned with the quantized " "The output size is not aligned with the quantized "
"weight shape. This can be caused by too large " "weight shape. This can be caused by too large "
...@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase): ...@@ -201,11 +201,13 @@ class GPTQLinearMethod(LinearMethodBase):
else: else:
weights["g_idx"] = torch.empty((1, 1), device="meta") weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) ops.gptq_shuffle(weights["qweight"], weights["g_idx"],
self.quant_config.weight_bits)
output = ops.gptq_gemm(reshaped_x, weights["qweight"], output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"], weights["qzeros"], weights["scales"],
weights["g_idx"], weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY) weights["exllama_state"] == ExllamaState.READY,
self.quant_config.weight_bits)
if bias is not None: if bias is not None:
output = output + bias output = output + bias
return output.reshape(out_shape) return output.reshape(out_shape)
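A short worked example of why pack_factor becomes a Fraction above: for 3-bit GPTQ, 32 // 3 would silently round down to 10, whereas the Fraction keeps the exact ratio and the alignment check can use its numerator (sizes are illustrative):
from fractions import Fraction

pack_factor = Fraction(32, 3)   # 3-bit weights: 32/3 values per int32
assert (pack_factor.numerator, pack_factor.denominator) == (32, 3)
assert 32 // 3 == 10            # plain integer division loses information

output_size_per_partition = 4096
# Mirrors the check in GPTQLinearMethod.create_weights.
assert output_size_per_partition % pack_factor.numerator == 0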
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
class MarlinConfig(QuantizationConfig):
"""Config class for Marlin.
Reference: https://github.com/IST-DASLab/marlin/tree/master
"""
def __init__(
self,
group_size: int,
) -> None:
# Group size for the quantization.
self.group_size = group_size
if self.group_size != 128 and self.group_size != -1:
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) is supported for "
f"Marlin, but got group_size of {self.group_size}")
# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
# Tile size used by marlin kernels.
self.tile_size = 16
# Min out_features dim
self.min_n_threads = 64
# Min in_features dim
self.min_k_threads = 128
# Max parallel problems to solve at once (improves large batch performance)
self.max_parallel = 16
# Permutation length used by the marlin kernels.
self.perm_len = 1024
def __repr__(self) -> str:
return f"MarlinConfig(group_size={self.group_size}"
@classmethod
def get_name(cls) -> str:
return "marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# TODO: confirm the minimum compute capability actually required by Marlin.
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
return cls(group_size)
def get_linear_method(self) -> "MarlinLinearMethod":
return MarlinLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class MarlinLinearMethod(LinearMethodBase):
"""Linear method for Marlin.
Args:
quant_config: The Marlin quantization config.
"""
def __init__(self, quant_config: MarlinConfig):
self.quant_config = quant_config
def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
del output_size # Unused.
if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
# Validate output_size_per_partition
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}."
)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}."
)
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}."
)
if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}."
)
# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
# Determine if channelwise or not
input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size
scales = Parameter(
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
return {
"B": qweight,
"s": scales,
"workspace": workspace,
}
def apply_weights(
self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = weights["B"]
scales = weights["s"]
workspace = weights["workspace"]
x_2d = x.view(-1, x.shape[-1])
size_m = x_2d.shape[0]
size_k = x_2d.shape[1]
size_n = scales.shape[1]
output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
size_n, size_k)
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
if bias is not None:
output.add_(bias) # In-place add
return output
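A hedged arithmetic sketch of the packed tensor shapes created in create_weights above, using illustrative partition sizes that satisfy Marlin's divisibility checks (tile_size=16, pack_factor=8, group_size=128, min_n_threads=64, min_k_threads=128):
input_size_per_partition = 4096
output_size_per_partition = 4096
tile_size, pack_factor, group_size = 16, 32 // 4, 128

qweight_shape = (input_size_per_partition // tile_size,                  # 256
                 output_size_per_partition * tile_size // pack_factor)   # 8192
scales_shape = (input_size_per_partition // group_size,                  # 32
                output_size_per_partition)                               # 4096
num_tiles_per_perm = 1024 // tile_size ** 2                              # 4
assert output_size_per_partition % num_tiles_per_perm == 0
print(qweight_shape, scales_shape)  # (256, 8192) (32, 4096)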
...@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int, ...@@ -245,13 +245,11 @@ def _yarn_find_correction_range(low_rot: int,
def _yarn_linear_ramp_mask(low: float, high: float, dim: int, def _yarn_linear_ramp_mask(low: float, high: float, dim: int,
dtype: torch.dtype, dtype: torch.dtype) -> torch.Tensor:
device: torch.device) -> torch.Tensor:
if low == high: if low == high:
high += 0.001 # Prevent singularity high += 0.001 # Prevent singularity
linear_func = (torch.arange(dim, dtype=dtype, device=device) - linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
low) / (high - low)
ramp_func = torch.clamp(linear_func, 0, 1) ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func return ramp_func
...@@ -356,7 +354,6 @@ def get_rope( ...@@ -356,7 +354,6 @@ def get_rope(
elif scaling_type == "yarn": elif scaling_type == "yarn":
original_max_position = rope_scaling[ original_max_position = rope_scaling[
"original_max_position_embeddings"] "original_max_position_embeddings"]
assert max_position == original_max_position * scaling_factor
extra_kwargs = { extra_kwargs = {
k: v k: v
for k, v in rope_scaling.items() for k, v in rope_scaling.items()
......
...@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens ...@@ -10,6 +10,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTens
from vllm.sampling_params import SamplingParams, SamplingType from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput,
SequenceData, SequenceGroupOutput, SequenceOutput) SequenceData, SequenceGroupOutput, SequenceOutput)
from vllm.utils import is_neuron
class Sampler(nn.Module): class Sampler(nn.Module):
...@@ -32,6 +33,8 @@ class Sampler(nn.Module): ...@@ -32,6 +33,8 @@ class Sampler(nn.Module):
org_vocab_size: Optional[int] = None) -> None: org_vocab_size: Optional[int] = None) -> None:
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
# Transformers-neuronx generate outputs as logits directly.
self.logits_as_hidden_states = is_neuron()
# original vocabulary size (without LoRA). # original vocabulary size (without LoRA).
self.org_vocab_size = org_vocab_size or vocab_size self.org_vocab_size = org_vocab_size or vocab_size
...@@ -55,10 +58,14 @@ class Sampler(nn.Module): ...@@ -55,10 +58,14 @@ class Sampler(nn.Module):
embedding_bias: Optional[torch.Tensor] = None, embedding_bias: Optional[torch.Tensor] = None,
) -> Optional[SamplerOutput]: ) -> Optional[SamplerOutput]:
# Get the hidden states that we use for sampling. # Get the hidden states that we use for sampling.
hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) if self.logits_as_hidden_states:
logits = hidden_states
else:
hidden_states = _prune_hidden_states(hidden_states,
sampling_metadata)
# Get the logits for the next tokens. # Get the logits for the next tokens.
logits = self._get_logits(hidden_states, embedding, embedding_bias) logits = self._get_logits(hidden_states, embedding, embedding_bias)
# Only perform sampling in the driver worker. # Only perform sampling in the driver worker.
# Note: `_get_logits` is still distributed across TP workers because # Note: `_get_logits` is still distributed across TP workers because
...@@ -395,7 +402,8 @@ def _sample( ...@@ -395,7 +402,8 @@ def _sample(
sample_metadata[sampling_type] = (seq_group_ids, seq_groups, sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices) is_prompts, sample_indices)
if sampling_type == SamplingType.GREEDY: if sampling_type == SamplingType.GREEDY:
greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) greedy_samples = torch.argmax(logprobs[sample_indices.long()],
dim=-1)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1 max_best_of = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts): for seq_group, is_prompt in zip(seq_groups, is_prompts):
...@@ -407,7 +415,7 @@ def _sample( ...@@ -407,7 +415,7 @@ def _sample(
"generators": sampling_metadata.generators, "generators": sampling_metadata.generators,
} }
multinomial_samples[sampling_type] = _multinomial( multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices], max_best_of, **seeded_args) probs[sample_indices.long()], max_best_of, **seeded_args)
elif sampling_type == SamplingType.BEAM: elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices] beam_search_logprobs = logprobs[sample_indices]
else: else:
......
...@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0": ...@@ -45,6 +45,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h, stride_v_cache_h,
stride_v_cache_d, stride_v_cache_d,
stride_v_cache_bl, stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
...@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0": ...@@ -53,6 +54,8 @@ if triton.__version__ >= "2.1.0":
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
start_m = tl.program_id(2) start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
...@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0": ...@@ -85,13 +88,14 @@ if triton.__version__ >= "2.1.0":
mask=(start_n + offs_n) < cur_batch_ctx_len, mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) other=0)
off_k = (bn[None, :] * stride_k_cache_bs + off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * ((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl + stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x) (offs_d[:, None] % x) * stride_k_cache_x)
off_v = ( off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d + offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k, k = tl.load(K_cache + off_k,
...@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0": ...@@ -131,9 +135,9 @@ if triton.__version__ >= "2.1.0":
l_i = l_i_new l_i = l_i_new
m_i = m_i_new m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd) offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd) offs_d[None, :] * stride_vd)
k_ptrs = K + off_k k_ptrs = K + off_k
v_ptrs = V + off_v v_ptrs = V + off_v
...@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0": ...@@ -232,6 +236,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h, stride_v_cache_h,
stride_v_cache_d, stride_v_cache_d,
stride_v_cache_bl, stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
...@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0": ...@@ -240,6 +245,8 @@ if triton.__version__ >= "2.1.0":
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
start_m = tl.program_id(2) start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
...@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0": ...@@ -272,13 +279,14 @@ if triton.__version__ >= "2.1.0":
mask=(start_n + offs_n) < cur_batch_ctx_len, mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) other=0)
off_k = (bn[None, :] * stride_k_cache_bs + off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * ((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl + stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x) (offs_d[:, None] % x) * stride_k_cache_x)
off_v = ( off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d + offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k, k = tl.load(K_cache + off_k,
...@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0": ...@@ -317,9 +325,9 @@ if triton.__version__ >= "2.1.0":
l_i = l_i_new l_i = l_i_new
m_i = m_i_new m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd) offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd) offs_d[None, :] * stride_vd)
k_ptrs = K + off_k k_ptrs = K + off_k
v_ptrs = V + off_v v_ptrs = V + off_v
...@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0": ...@@ -420,6 +428,7 @@ if triton.__version__ >= "2.1.0":
stride_v_cache_h, stride_v_cache_h,
stride_v_cache_d, stride_v_cache_d,
stride_v_cache_bl, stride_v_cache_bl,
num_queries_per_kv: int,
BLOCK_M: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr, BLOCK_N: tl.constexpr,
...@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0": ...@@ -429,6 +438,8 @@ if triton.__version__ >= "2.1.0":
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
start_m = tl.program_id(2) start_m = tl.program_id(2)
cur_kv_head = cur_head // num_queries_per_kv
# cur_batch_seq_len: the length of prompts # cur_batch_seq_len: the length of prompts
# cur_batch_ctx_len: the length of prefix # cur_batch_ctx_len: the length of prefix
# cur_batch_in_all_start_index: the start id of the dim=0 # cur_batch_in_all_start_index: the start id of the dim=0
...@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0": ...@@ -468,13 +479,14 @@ if triton.__version__ >= "2.1.0":
mask=(start_n + offs_n) < cur_batch_ctx_len, mask=(start_n + offs_n) < cur_batch_ctx_len,
other=0) other=0)
off_k = (bn[None, :] * stride_k_cache_bs + off_k = (bn[None, :] * stride_k_cache_bs +
cur_head * stride_k_cache_h + cur_kv_head * stride_k_cache_h +
(offs_d[:, None] // x) * stride_k_cache_d + (offs_d[:, None] // x) * stride_k_cache_d +
((start_n + offs_n[None, :]) % block_size) * ((start_n + offs_n[None, :]) % block_size) *
stride_k_cache_bl + stride_k_cache_bl +
(offs_d[:, None] % x) * stride_k_cache_x) (offs_d[:, None] % x) * stride_k_cache_x)
off_v = ( off_v = (
bn[:, None] * stride_v_cache_bs + cur_head * stride_v_cache_h + bn[:, None] * stride_v_cache_bs +
cur_kv_head * stride_v_cache_h +
offs_d[None, :] * stride_v_cache_d + offs_d[None, :] * stride_v_cache_d +
(start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl)
k = tl.load(K_cache + off_k, k = tl.load(K_cache + off_k,
...@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0": ...@@ -522,9 +534,9 @@ if triton.__version__ >= "2.1.0":
l_i = l_i_new l_i = l_i_new
m_i = m_i_new m_i = m_i_new
off_k = (offs_n[None, :] * stride_kbs + cur_head * stride_kh + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh +
offs_d[:, None] * stride_kd) offs_d[:, None] * stride_kd)
off_v = (offs_n[:, None] * stride_vbs + cur_head * stride_vh + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh +
offs_d[None, :] * stride_vd) offs_d[None, :] * stride_vd)
k_ptrs = K + off_k k_ptrs = K + off_k
v_ptrs = V + off_v v_ptrs = V + off_v
...@@ -537,7 +549,7 @@ if triton.__version__ >= "2.1.0": ...@@ -537,7 +549,7 @@ if triton.__version__ >= "2.1.0":
alibi_start_q = tl.arange( alibi_start_q = tl.arange(
0, BLOCK_M) + block_start_loc + cur_batch_ctx_len 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len
alibi_start_k = cur_batch_ctx_len alibi_start_k = cur_batch_ctx_len
# # init debuger # # init debugger
# offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc
# offset_db_k = tl.arange(0, BLOCK_N) # offset_db_k = tl.arange(0, BLOCK_N)
# calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL]
...@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0": ...@@ -628,6 +640,7 @@ if triton.__version__ >= "2.1.0":
sm_scale = 1.0 / (Lq**0.5) sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1] batch, head = b_seq_len.shape[0], q.shape[1]
num_queries_per_kv = q.shape[1] // k.shape[1]
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,
...@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0": ...@@ -674,6 +687,7 @@ if triton.__version__ >= "2.1.0":
v_cache.stride(2), v_cache.stride(2),
v_cache.stride( v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size] 3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK, BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
...@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0": ...@@ -721,6 +735,7 @@ if triton.__version__ >= "2.1.0":
v_cache.stride(2), v_cache.stride(2),
v_cache.stride( v_cache.stride(
3), #[num_blocks, num_kv_heads, head_size, block_size] 3), #[num_blocks, num_kv_heads, head_size, block_size]
num_queries_per_kv=num_queries_per_kv,
BLOCK_M=BLOCK, BLOCK_M=BLOCK,
BLOCK_DMODEL=Lk, BLOCK_DMODEL=Lk,
BLOCK_N=BLOCK, BLOCK_N=BLOCK,
......
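A hedged illustration of the num_queries_per_kv plumbing added above: each query head maps to its grouped KV head by integer division, so with 32 query heads and 8 KV heads every block of 4 consecutive query heads reads the same cached K/V (the head counts are examples only):
num_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_heads // num_kv_heads   # 4, computed before the kernel launch

# Mirrors `cur_kv_head = cur_head // num_queries_per_kv` inside the kernels.
mapping = {cur_head: cur_head // num_queries_per_kv for cur_head in range(num_heads)}
assert mapping[0] == mapping[3] == 0
assert mapping[4] == 1 and mapping[31] == 7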
"""Utilities for selecting and loading models.""" """Utilities for selecting and loading models."""
import contextlib import contextlib
from typing import Optional, Type from typing import Type
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm.config import DeviceConfig, ModelConfig, LoRAConfig from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config, from vllm.model_executor.weight_utils import (get_quant_config,
initialize_dummy_weights) initialize_dummy_weights)
...@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: ...@@ -37,9 +37,9 @@ def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]:
f"Supported architectures: {ModelRegistry.get_supported_archs()}") f"Supported architectures: {ModelRegistry.get_supported_archs()}")
def get_model(model_config: ModelConfig, def get_model(model_config: ModelConfig, device_config: DeviceConfig,
device_config: DeviceConfig, **kwargs) -> nn.Module:
lora_config: Optional[LoRAConfig] = None) -> nn.Module: lora_config = kwargs.get("lora_config", None)
model_class = _get_model_architecture(model_config) model_class = _get_model_architecture(model_config)
# Get the (maybe quantized) linear method. # Get the (maybe quantized) linear method.
......
...@@ -4,7 +4,7 @@ from typing import List, Optional, Type ...@@ -4,7 +4,7 @@ from typing import List, Optional, Type
import torch.nn as nn import torch.nn as nn
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_hip from vllm.utils import is_hip, is_neuron
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -30,7 +30,7 @@ _MODELS = { ...@@ -30,7 +30,7 @@ _MODELS = {
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"), "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-* # For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("mistral", "MistralForCausalLM"), "MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case # transformers's mpt class has lower case
...@@ -38,11 +38,14 @@ _MODELS = { ...@@ -38,11 +38,14 @@ _MODELS = {
"MPTForCausalLM": ("mpt", "MPTForCausalLM"), "MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"), "OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
} }
# Models not supported by ROCm. # Models not supported by ROCm.
...@@ -59,6 +62,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = { ...@@ -59,6 +62,9 @@ _ROCM_PARTIALLY_SUPPORTED_MODELS = {
"Sliding window attention is not yet supported in ROCm's flash attention", "Sliding window attention is not yet supported in ROCm's flash attention",
} }
# Models not supported by Neuron.
_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"}
class ModelRegistry: class ModelRegistry:
...@@ -75,8 +81,15 @@ class ModelRegistry: ...@@ -75,8 +81,15 @@ class ModelRegistry:
logger.warning( logger.warning(
f"Model architecture {model_arch} is partially supported " f"Model architecture {model_arch} is partially supported "
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
elif is_neuron():
if model_arch not in _NEURON_SUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"Neuron for now.")
module_name, model_cls_name = _MODELS[model_arch] module_name, model_cls_name = _MODELS[model_arch]
if is_neuron():
module_name = _NEURON_SUPPORTED_MODELS[model_arch]
module = importlib.import_module( module = importlib.import_module(
f"vllm.model_executor.models.{module_name}") f"vllm.model_executor.models.{module_name}")
return getattr(module, model_cls_name, None) return getattr(module, model_cls_name, None)
......