Commit 2216a4e5 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
......@@ -16,7 +16,6 @@ Typical output looks like this:
2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time)
3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time)
Longest build steps for .so (linking):
0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time)
0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time)
0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time)
6.2 weighted s to build _C.abi3.so (6.2 s elapsed time)
......
import importlib.util
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
core_C_available = importlib.util.find_spec('._core_C', 'vllm') is not None
# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    """Enumerates the ways a scalar type can encode NaN values."""

    # NaNs are not supported.
    NONE = 0

    # NaNs are: exponent all 1s, mantissa not all 0s.
    IEEE_754 = 1

    # NaNs are: exponent all 1s, mantissa all 1s.
    EXTD_RANGE_MAX_MIN = 2
if TYPE_CHECKING or not core_C_available:
# On platforms where we cannot use/build the C++ core extension (namely
# neuron and tpu), we define the mock ScalarType class here that partially
# mimics the C++ ScalarType class.
#
# We also use this to provide type signatures to the Python LSP for the
# methods in the C++ ScalarType class. So these type signatures should be
# kept in sync with csrc/core/scalar_type.hpp
from dataclasses import dataclass
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.:
      `stored_value = value + bias`,
    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type)
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer excluding the sign bit if
    this is an integer type.
    """

    bias: int
    """
    bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0) for example if we store the
    type as an unsigned integer with a bias of 128 then the value 0 will be
    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    _finite_values_only: bool = False
    """
    Private: when True the type reports no support for infinities; query
    this through `has_infs()` instead of reading it directly.
    """

    nan_repr: int = NanRepr.IEEE_754.value
    """
    How NaNs are represented in this scalar type, holds a NanRepr value.
    (not applicable for integer types)
    """

    @property
    def size_bits(self):
        "Total width in bits: exponent + mantissa + one sign bit if signed."
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)

        Not implemented by this mock; always raises NotImplementedError.
        """
        raise NotImplementedError

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)

        Not implemented by this mock; always raises NotImplementedError.
        """
        raise NotImplementedError

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        # Previously the body was `...`, which silently returned None.
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        # NOTE(review): this does not check is_floating_point(), so integer
        # types also report True unless built with finite_values_only --
        # confirm against csrc/core/scalar_type.hpp.
        return not self._finite_values_only

    def has_nans(self) -> bool:
        "If the type's NaN representation is anything other than NONE"
        return self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754.value and \
            not self._finite_values_only

    def __str__(self) -> str:
        raise NotImplementedError

    def __repr__(self) -> str:
        raise NotImplementedError

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        # Integer types have no exponent bits; the sign bit is accounted for
        # by `signed`, so the mantissa holds the remaining size_bits - 1.
        return cls(0, size_bits - 1, bias or 0, True)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        """Create an unsigned integer scalar type."""
        # Integer types have no exponent bits; with no sign bit every bit
        # belongs to the mantissa.
        return cls(0, size_bits, bias or 0, False)

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True)

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        return cls(exponent, mantissa, 0, True, finite_values_only,
                   nan_repr)
elif core_C_available:
try:
import vllm._core_C # noqa: F401
except ImportError as e:
logger.warning("Failed to import from vllm._core_C with %r", e)
ScalarType = torch.classes._core_C.ScalarType
if (hasattr(torch, "_library")
and hasattr(torch._library, "register_fake_class")):
# Needed for dynamo support of ScalarType.
@torch._library.register_fake_class("_core_C::ScalarType")
class FakeScalarType:
    """
    Fake-class counterpart registered for `_core_C::ScalarType`.

    Stores the wrapped scalar type object and forwards every attribute
    read and method call to it.
    """

    def __init__(self, scalar_type):
        # Every member below delegates to this wrapped object.
        self.ScalarType = scalar_type

    # --- attribute getters -------------------------------------------------

    def bias_getter(self) -> int:
        wrapped = self.ScalarType
        return wrapped.bias

    def exponent_getter(self) -> int:
        wrapped = self.ScalarType
        return wrapped.exponent

    def mantissa_getter(self) -> int:
        wrapped = self.ScalarType
        return wrapped.mantissa

    def signed_getter(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.signed

    def size_bits_getter(self) -> int:
        wrapped = self.ScalarType
        return wrapped.size_bits

    @property
    def size_bits(self) -> int:
        wrapped = self.ScalarType
        return wrapped.size_bits

    # --- forwarded queries -------------------------------------------------

    def min(self) -> Union[int, float]:
        wrapped = self.ScalarType
        return wrapped.min()

    def max(self) -> Union[int, float]:
        wrapped = self.ScalarType
        return wrapped.max()

    def is_signed(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.is_signed()

    def is_floating_point(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.is_floating_point()

    def is_integer(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.is_integer()

    def has_bias(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.has_bias()

    def has_infs(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.has_infs()

    def has_nans(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.has_nans()

    def is_ieee_754(self) -> bool:
        wrapped = self.ScalarType
        return wrapped.is_ieee_754()

    # --- forwarded dunder methods ------------------------------------------

    def __str__(self) -> str:
        wrapped = self.ScalarType
        return wrapped.__str__()

    def __repr__(self) -> str:
        wrapped = self.ScalarType
        return wrapped.__repr__()

    def __len__(self) -> int:
        wrapped = self.ScalarType
        return wrapped.__len__()

    # --- flatten / unflatten hooks -----------------------------------------

    def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]:
        real_cls = torch.classes._core_C.ScalarType
        return real_cls.__obj_flatten__(self.ScalarType)

    @classmethod
    def __obj_unflatten__(
            cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType':
        real_cls = torch.classes._core_C.ScalarType
        rebuilt = real_cls.__obj_unflatten__(flat_type)
        return cls(rebuilt)

    # --- constructors, forwarded to the module-level ScalarType ------------

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        made = ScalarType.int_(size_bits, bias)
        return made

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType':
        made = ScalarType.uint(size_bits, bias)
        return made

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType':
        made = ScalarType.float_IEEE754(exponent, mantissa)
        return made

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: int) -> 'ScalarType':
        made = ScalarType.float_(exponent, mantissa, finite_values_only,
                                 nan_repr)
        return made
......@@ -6,9 +6,9 @@ import torch
import torch.library
import vllm.envs as envs
from vllm._core_ext import ScalarType
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType
try:
from lmslim import quant_ops
......@@ -31,7 +31,8 @@ with contextlib.suppress(ImportError):
import vllm._moe_C # noqa: F401
supports_moe_ops = True
if TYPE_CHECKING:
# neuron has torch version that doesn't even have impl_abstract
if TYPE_CHECKING or current_platform.is_neuron():
def register_fake(fn):
return lambda name: fn
......@@ -503,7 +504,7 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace: torch.Tensor, b_q_type: ScalarType,
size_m: int, size_n: int, size_k: int) -> torch.Tensor:
return torch.ops._C.gptq_marlin_24_gemm(a, b_q_weight, b_meta, b_scales,
workspace, b_q_type, size_m,
workspace, b_q_type.id, size_m,
size_n, size_k)
......@@ -513,8 +514,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _gptq_marlin_24_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_meta: torch.Tensor, b_scales: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int,
size_n: int, size_k: int) -> torch.Tensor:
b_q_type: ScalarType, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::gptq_marlin_gemm")
......@@ -526,17 +528,18 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
perm: torch.Tensor,
workspace: torch.Tensor,
b_q_type: ScalarType,
size_m: int,
size_n: int,
size_k: int,
size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt,
is_k_full: bool,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor:
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
......@@ -544,7 +547,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device)
......@@ -553,7 +556,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
row: torch.SymInt,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device)
......@@ -562,8 +565,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
s_tok: torch.Tensor, s_ch: torch.Tensor,
s_group: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
......@@ -571,16 +574,16 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::marlin_gemm")
def _marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
size_m: torch.SymInt, size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n),
dtype=torch.float16,
device=a.device)
@register_fake("_C::awq_dequantize")
def _awq_dequantize_fake(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
zeros: torch.Tensor, split_k_iters: torch.SymInt,
thx: int, thy: int) -> torch.Tensor:
in_c = qweight.size(0)
qout_c = qweight.size(1)
out_c = qout_c * 8
......@@ -591,7 +594,7 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::awq_gemm")
def _awq_gemm_fake(input: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, scales: torch.Tensor,
split_k_iters: int) -> torch.Tensor:
split_k_iters: torch.SymInt) -> torch.Tensor:
num_in_feats = input.size(0)
return torch.empty((split_k_iters, num_in_feats, qweight.size(1) * 8),
dtype=input.dtype,
......@@ -626,8 +629,9 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
@register_fake("_C::fp8_marlin_gemm")
def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
b_scales: torch.Tensor, workspace: torch.Tensor,
num_bits: int, size_m: int, size_n: int,
size_k: int) -> torch.Tensor:
num_bits: int, size_m: torch.SymInt,
size_n: torch.SymInt,
size_k: torch.SymInt) -> torch.Tensor:
return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device)
@register_fake("_C::machete_gemm")
......@@ -654,40 +658,6 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
conv_states: Optional[torch.Tensor],
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
silu_activation: bool, pad_slot_id: int):
return None
@register_fake("_C::causal_conv1d_update")
def causal_conv1d_update_fake(x: torch.Tensor, conv_state: torch.Tensor,
weight: torch.Tensor,
bias_: Optional[torch.Tensor],
silu_activation: bool,
cache_seqlens: Optional[torch.Tensor],
conv_state_indices: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
@register_fake("_C::selective_scan_fwd")
def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
A: torch.Tensor, B: torch.Tensor,
C: torch.Tensor, D_: Optional[torch.Tensor],
z_: Optional[torch.Tensor],
delta_bias_: Optional[torch.Tensor],
delta_softplus: bool,
cu_seq_len: Optional[torch.Tensor],
cache_indices: Optional[torch.Tensor],
has_initial_state: Optional[torch.Tensor],
ssm_states: Optional[torch.Tensor],
pad_slot_id: int) -> None:
return None
# cutlass
def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
......@@ -828,7 +798,7 @@ def gptq_marlin_gemm(a: torch.Tensor,
has_zp: bool = False,
use_fp32_reduce: bool = False) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
g_idx, perm, workspace, b_q_type,
g_idx, perm, workspace, b_q_type.id,
size_m, size_n, size_k, is_k_full,
has_zp, use_fp32_reduce)
......@@ -844,7 +814,7 @@ def fp8_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# machete
def machete_supported_schedules(b_type: ScalarType) -> List[str]:
return torch.ops._C.machete_supported_schedules(b_type)
return torch.ops._C.machete_supported_schedules(b_type.id)
def machete_gemm(
......@@ -859,13 +829,13 @@ def machete_gemm(
beta: Optional[float] = None,
schedule: Optional[str] = None,
) -> torch.Tensor:
return torch.ops._C.machete_gemm(a, b_q, b_type, b_scales, b_zeros,
return torch.ops._C.machete_gemm(a, b_q, b_type.id, b_scales, b_zeros,
b_group_size, c, alpha, beta, schedule)
def machete_prepack_B(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor:
return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
return torch.ops._C.machete_prepack_B(b_q_weight, b_type.id)
if hasattr(torch.ops._C, "permute_cols"):
......@@ -1079,10 +1049,10 @@ if supports_moe_ops and hasattr(torch.ops._moe_C, "marlin_gemm_moe"):
topk_ids: torch.Tensor, b_scales: torch.Tensor,
b_zero_points: torch.Tensor, g_idx: torch.Tensor,
perm: torch.Tensor, workspace: torch.Tensor,
b_q_type: ScalarType, size_m: int, size_n: int,
size_k: int, is_k_full: bool, num_experts: int,
topk: int, moe_block_size: int,
replicate_input: bool,
b_q_type: ScalarType, size_m: torch.SymInt,
size_n: torch.SymInt, size_k: torch.SymInt,
is_k_full: bool, num_experts: int, topk: int,
moe_block_size: int, replicate_input: bool,
apply_weights: bool) -> torch.Tensor:
return torch.empty((size_m, topk, size_n),
dtype=a.dtype,
......
......@@ -15,8 +15,11 @@ if TYPE_CHECKING:
class AttentionType(Enum):
DECODER = auto() # Decoder attention between previous layer Q/K/V
ENCODER = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V
ENCODER = auto(
) # Encoder attention between previous layer Q/K/V for encoder-decoder
ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V
ENCODER_DECODER = auto(
) # Attention between dec. Q and enc. K/V for encoder-decoder
class AttentionBackend(ABC):
......
......@@ -32,7 +32,7 @@ class FlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flash-attn"
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
......@@ -524,8 +524,8 @@ class FlashAttentionImpl(AttentionImpl):
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.sliding_window = ((sliding_window, sliding_window)
if sliding_window is not None else (-1, -1))
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
......@@ -535,12 +535,6 @@ class FlashAttentionImpl(AttentionImpl):
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
# paged KV cache.
raise ValueError(
"Sliding window is not supported in FlashAttention.")
support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
if head_size not in support_head_sizes:
raise ValueError(
......@@ -704,6 +698,7 @@ def unified_flash_attention(
max_seqlen_k=max_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
block_table=prefill_meta.block_tables,
softcap=logits_soft_cap,
......@@ -725,6 +720,7 @@ def unified_flash_attention(
max_seqlen_k=decode_meta.max_decode_seq_len,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
block_table=decode_meta.block_tables,
......@@ -739,6 +735,7 @@ def unified_flash_attention(
cache_seqlens=decode_meta.seq_lens_tensor,
softmax_scale=softmax_scale,
causal=True,
window_size=window_size,
alibi_slopes=alibi_slopes,
softcap=logits_soft_cap,
).squeeze(1)
......
......@@ -17,6 +17,7 @@ except ImportError:
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata,
......@@ -39,7 +40,7 @@ class FlashInferBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "flashinfer"
return "FLASHINFER"
@staticmethod
def get_impl_cls() -> Type["FlashInferImpl"]:
......@@ -124,7 +125,8 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
"NHD",
......@@ -183,7 +185,8 @@ class FlashInferState(AttentionState):
self.runner.parallel_config))
num_kv_heads = self.runner.model_config.get_num_kv_heads(
self.runner.parallel_config)
use_tensor_cores = num_qo_heads // num_kv_heads > 4
use_tensor_cores = envs.VLLM_FLASHINFER_FORCE_TENSOR_CORES or (
num_qo_heads // num_kv_heads > 4)
self._graph_decode_wrapper = \
CUDAGraphBatchDecodeWithPagedKVCacheWrapper(
self._graph_decode_workspace_buffer, _indptr_buffer,
......
......@@ -19,7 +19,7 @@ class IpexAttnBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "ipex-attn"
return "IPEX"
@staticmethod
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
......
......@@ -38,7 +38,7 @@ class OpenVINOAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "openvino"
return "OPENVINO"
@staticmethod
def get_impl_cls():
......
......@@ -11,6 +11,10 @@ from vllm.attention.backends.utils import CommonAttentionState
class PallasAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "PALLAS"
@staticmethod
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
return PallasAttentionBackendImpl
......
......@@ -20,7 +20,7 @@ class PlaceholderAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "placeholder-attn"
return "NO_ATTENTION"
@staticmethod
def get_impl_cls() -> Type["PlaceholderAttentionImpl"]:
......
......@@ -28,7 +28,7 @@ class ROCmFlashAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "rocm-flash-attn"
return "ROCM_FLASH"
@staticmethod
def get_impl_cls() -> Type["ROCmFlashAttentionImpl"]:
......
......@@ -10,9 +10,9 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
from vllm.utils import is_cpu
from vllm.platforms import current_platform
if is_cpu():
if current_platform.is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
......@@ -25,7 +25,7 @@ class TorchSDPABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "torch-sdpa"
return "TORCH_SDPA"
@staticmethod
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
......@@ -234,10 +234,10 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
on the type of attention operation.
Decoder attn -> select entirely decoder self-attention-related fields
Encoder/decoder cross-attn -> select encoder sequence lengths &
Encoder/decoder cross-attn -> select encoder sequence lengths &
cross-attn block-tables fields
Encoder attn -> select encoder sequence lengths fields & no block tables
Arguments:
* attn_metadata: Attention metadata structure associated with attention
......
......@@ -317,8 +317,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
......@@ -337,8 +337,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
......@@ -356,8 +356,8 @@ class CommonAttentionState(AttentionState):
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
assert self.runner.attn_backend.get_name() == "XFORMERS", \
f"Expected attn_backend name to be 'XFORMERS', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
......
......@@ -24,7 +24,7 @@ class XFormersBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "xformers"
return "XFORMERS"
@staticmethod
def get_impl_cls() -> Type["XFormersImpl"]:
......@@ -287,13 +287,15 @@ def _get_attn_bias(
* Appropriate attention bias value given the attention type
'''
if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
return attn_metadata.attn_bias
elif attn_type == AttentionType.ENCODER:
return attn_metadata.encoder_attn_bias
else:
# attn_type == AttentionType.ENCODER_DECODER
elif attn_type == AttentionType.ENCODER_DECODER:
return attn_metadata.cross_attn_bias
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
def _set_attn_bias(
......@@ -313,7 +315,8 @@ def _set_attn_bias(
encoder/decoder cross-attention
'''
if attn_type == AttentionType.DECODER:
if (attn_type == AttentionType.DECODER
or attn_type == AttentionType.ENCODER_ONLY):
attn_metadata.attn_bias = attn_bias
elif attn_type == AttentionType.ENCODER:
attn_metadata.encoder_attn_bias = attn_bias
......@@ -371,6 +374,12 @@ def _get_seq_len_block_table_args(
# No block tables associated with encoder attention
return (attn_metadata.encoder_seq_lens_tensor,
attn_metadata.max_encoder_seq_len, None)
elif attn_type == AttentionType.ENCODER_ONLY:
assert is_prompt, "Should not have decode for encoder only model."
# No block tables associated with encoder attention
return (attn_metadata.seq_lens_tensor,
attn_metadata.max_prefill_seq_len, None)
else:
raise AttributeError(f"Invalid attention type {str(attn_type)}")
......@@ -479,7 +488,10 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
* ENCODER: no KV caching; pass encoder sequence
attributes (encoder_seq_lens/encoder_seq_lens_tensor/
max_encoder_seq_len) to kernel, in lieu of decoder
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
Used for encoder branch of encoder-decoder models.
* ENCODER_ONLY: no kv_caching, uses the normal attention
attributes (seq_lens/seq_lens_tensor/max_seq_len).
* ENCODER_DECODER: cross-attention behavior;
use cross-attention block table for caching KVs derived
from encoder hidden states; since KV sequence lengths
......@@ -509,6 +521,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
and (not attn_metadata.is_all_encoder_attn_metadata_set)):
raise AttributeError("Encoder attention requires setting "
"encoder metadata attributes.")
elif (attn_type == AttentionType.ENCODER_DECODER
and (not attn_metadata.is_all_cross_attn_metadata_set)):
raise AttributeError("Encoder/decoder cross-attention "
......@@ -609,6 +622,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
assert out.shape == output[:num_prefill_tokens].shape
output[:num_prefill_tokens] = out
else:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have prefix attention.")
assert prefill_meta.query_start_loc is not None
assert prefill_meta.max_query_len is not None
......@@ -638,6 +653,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
output[:num_prefill_tokens] = out
if decode_meta := attn_metadata.decode_metadata:
assert attn_type != AttentionType.ENCODER_ONLY, (
"Encoder-only models should not have decode metadata.")
(
seq_lens_arg,
......@@ -703,36 +720,60 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
None, :].expand(value.shape[0], self.num_kv_heads,
self.num_queries_per_kv,
value.shape[-1])
# Set attention bias if not provided. This typically happens at
# the very first attention layer of every iteration.
# FIXME(woosuk): This is a hack.
attn_bias = _get_attn_bias(attn_metadata, attn_type)
if attn_bias is None:
if self.alibi_slopes is None:
# Cross attention block of decoder branch of encoder-decoder
# model uses seq_lens for dec / encoder_seq_lens for enc
if (attn_type == AttentionType.ENCODER_DECODER):
assert attn_metadata.seq_lens is not None
assert attn_metadata.encoder_seq_lens is not None
# Default enc/dec cross-attention mask is non-causal
# Cross-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
# Encoder branch of encoder-decoder model uses
# attn_metadata.encoder_seq_lens
elif attn_type == AttentionType.ENCODER:
assert attn_metadata.encoder_seq_lens is not None
# Default encoder self-attention mask is non-causal
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.encoder_seq_lens)
else:
# Self-attention block of encoder-only model just
# uses the seq_lens directly.
elif attn_type == AttentionType.ENCODER_ONLY:
assert attn_metadata.seq_lens is not None
# Default decoder self-attention mask is causal
# Encoder self-attention mask is non-causal
attn_bias = BlockDiagonalMask.from_seqlens(
attn_metadata.seq_lens)
# Self-attention block of decoder branch just
# uses the seq_lens directly
elif attn_type == AttentionType.DECODER:
assert attn_metadata.seq_lens is not None
# Decoder self-attention mask is causal
attn_bias = BlockDiagonalCausalMask.from_seqlens(
attn_metadata.seq_lens)
else:
raise ValueError("Unknown AttentionType: %s", attn_type)
if self.sliding_window is not None:
attn_bias = attn_bias.make_local_attention(
self.sliding_window)
attn_bias = [attn_bias]
else:
assert attn_type == AttentionType.DECODER
assert attn_metadata.seq_lens is not None
attn_bias = _make_alibi_bias(self.alibi_slopes,
self.num_kv_heads, query.dtype,
......
......@@ -78,10 +78,9 @@ class Attention(nn.Module):
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size, sliding_window, dtype,
kv_cache_dtype, block_size,
is_attention_free, blocksparse_params
is not None)
attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
block_size, is_attention_free,
blocksparse_params is not None)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
......
......@@ -3,7 +3,7 @@ import math
import torch
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip
from vllm.utils import is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask)
......@@ -32,7 +32,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
):
super().__init__()
if use_spda is None:
use_spda = is_hip() or is_cpu() or not \
use_spda = is_hip() or current_platform.is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device()
if current_platform.is_cuda_alike() else "cpu")
......@@ -109,13 +109,13 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
Support grouped attention, with `q[:, i*r:(i*r + r)]`
is correspondent to `k[:, i]`, where `r` is the q/k ratio.
cu_seqlens_k: shape=(batch_size + 1,),
indicating segment of samples,
cu_seqlens_k: shape=(batch_size + 1,),
indicating segment of samples,
e.g., `k[cu_seqlen[i]:cu_seqlne[i+1]]` is q of sample i
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify is when q is a mix of
The only case you need to specify is when q is a mix of
prefilling and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
......@@ -171,7 +171,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
def spda(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""For CPU, V100 or other older GPUs.
NOTE: torch SPDA supports nested tensor,
NOTE: torch SPDA supports nested tensor,
but seems extremely slow. Choose to pad instead.
"""
assert (cu_seqlens_q is None or
......@@ -201,8 +201,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
return self.transpose_and_unpad(spda_output, cu_seqlens)
def forward(self, q, k, v, cu_seqlens_k, cu_seqlens_q=None, sm_scale=None):
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda` (cpu, Volta, Turing or older) based on
"""Dispatch to `varlen_attn` (Ampere or newer) or
`self.spda` (cpu, Volta, Turing or older) based on
the type of device used and cuda compute capability.
q, k, v: shape = (num_tokens, num_heads_q/kv, head_size).
......@@ -213,8 +213,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
cu_seqlens_q: shape=(batch_size + 1, ).
Default None: same as cu_seqlens_k for prefilling or
[0, 1, .., batch_size] for decoding.
The only case you need to specify
is when q is a mix of prefilling
The only case you need to specify
is when q is a mix of prefilling
and decoding.
sm_scale: softmax scale, default to 1/sqrt(head_size).
......
......@@ -10,13 +10,14 @@ import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import STR_BACKEND_ENV_VAR, is_cpu, is_hip, is_openvino, is_xpu
from vllm.utils import STR_BACKEND_ENV_VAR, is_hip, is_openvino, is_xpu
logger = init_logger(__name__)
class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
......@@ -90,7 +91,6 @@ def get_global_forced_attn_backend() -> Optional[_Backend]:
@lru_cache(maxsize=None)
def get_attn_backend(
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
......@@ -105,12 +105,16 @@ def get_attn_backend(
BlocksparseFlashAttentionBackend)
return BlocksparseFlashAttentionBackend
backend = which_attn_to_use(head_size, sliding_window, dtype,
kv_cache_dtype, block_size, is_attention_free)
backend = which_attn_to_use(head_size, dtype, kv_cache_dtype, block_size,
is_attention_free)
if backend == _Backend.FLASH_ATTN:
from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend)
return FlashAttentionBackend
if backend == _Backend.FLASH_ATTN_VLLM_V1:
from vllm.v1.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend as FlashAttentionBackendV1)
return FlashAttentionBackendV1
if backend == _Backend.XFORMERS:
logger.info("Using XFormers backend.")
from vllm.attention.backends.xformers import ( # noqa: F401
......@@ -122,7 +126,7 @@ def get_attn_backend(
ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA:
assert is_cpu(), RuntimeError(
assert current_platform.is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend
......@@ -155,7 +159,6 @@ def get_attn_backend(
def which_attn_to_use(
head_size: int,
sliding_window: Optional[int],
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
......@@ -185,7 +188,7 @@ def which_attn_to_use(
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if is_cpu():
if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA
......@@ -217,6 +220,9 @@ def which_attn_to_use(
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH
if envs.VLLM_USE_V1:
return _Backend.FLASH_ATTN_VLLM_V1
# FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN:
if not current_platform.has_device_capability(80):
......@@ -243,10 +249,6 @@ def which_attn_to_use(
"Cannot use FlashAttention-2 backend for block size not "
"divisible by 16.")
selected_backend = _Backend.XFORMERS
elif sliding_window is not None:
logger.info(
"Cannot use FlashAttention-2 backend due to sliding window.")
selected_backend = _Backend.XFORMERS
# FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN:
......
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional
from vllm.sequence import Logprob
@dataclass
......@@ -11,6 +13,7 @@ class BeamSearchSequence:
"""
# The tokens includes the prompt.
tokens: List[int]
logprobs: List[Dict[int, Logprob]]
cum_logprob: float = 0.0
text: Optional[str] = None
......@@ -28,7 +31,7 @@ class BeamSearchInstance:
def __init__(self, prompt_tokens: List[int]):
self.beams: List[BeamSearchSequence] = [
BeamSearchSequence(tokens=prompt_tokens)
BeamSearchSequence(tokens=prompt_tokens, logprobs=[])
]
self.completed: List[BeamSearchSequence] = []
......
"""Benchmark offline inference throughput."""
import argparse
import dataclasses
import json
import random
import time
......@@ -12,10 +13,9 @@ from tqdm import tqdm
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.entrypoints.openai.api_server import (
build_async_engine_client_from_engine_args)
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.sampling_params import BeamSearchParams
from vllm.utils import FlexibleArgumentParser, merge_async_iterators
......@@ -69,53 +69,11 @@ def sample_requests(
def run_vllm(
warmup_requests: List[Tuple[str, int, int]],
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
engine_args: EngineArgs,
) -> float:
from vllm import LLM, SamplingParams
llm = LLM(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
disable_async_output_proc=disable_async_output_proc,
)
llm = LLM(**dataclasses.asdict(engine_args))
# Add the requests to the engine.
prompts: List[str] = []
......@@ -192,56 +150,11 @@ def run_vllm(
async def run_vllm_async(
requests: List[Tuple[str, int, int]],
model: str,
tokenizer: str,
quantization: Optional[str],
tensor_parallel_size: int,
seed: int,
n: int,
trust_remote_code: bool,
dtype: str,
max_model_len: Optional[int],
enforce_eager: bool,
kv_cache_dtype: str,
quantization_param_path: Optional[str],
device: str,
enable_prefix_caching: bool,
enable_chunked_prefill: bool,
max_num_batched_tokens: int,
distributed_executor_backend: Optional[str],
gpu_memory_utilization: float = 0.9,
num_scheduler_steps: int = 1,
download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False,
engine_args: AsyncEngineArgs,
disable_frontend_multiprocessing: bool = False,
) -> float:
from vllm import SamplingParams
engine_args = AsyncEngineArgs(
model=model,
tokenizer=tokenizer,
quantization=quantization,
tensor_parallel_size=tensor_parallel_size,
seed=seed,
trust_remote_code=trust_remote_code,
dtype=dtype,
max_model_len=max_model_len,
gpu_memory_utilization=gpu_memory_utilization,
enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype,
quantization_param_path=quantization_param_path,
device=device,
enable_prefix_caching=enable_prefix_caching,
download_dir=download_dir,
enable_chunked_prefill=enable_chunked_prefill,
max_num_batched_tokens=max_num_batched_tokens,
distributed_executor_backend=distributed_executor_backend,
load_format=load_format,
num_scheduler_steps=num_scheduler_steps,
disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False,
disable_log_requests=True,
)
async with build_async_engine_client_from_engine_args(
engine_args, disable_frontend_multiprocessing) as llm:
......@@ -360,7 +273,16 @@ def main(args: argparse.Namespace):
for _ in range(1)]
if args.dataset is None:
# Synthesize a prompt with the given input length.
prompt = "hi" * (args.input_len - 1)
# As the tokenizer may add extra tokens such as BOS, we need to try
# different lengths to get the desired input length.
for i in range(-10, 10):
prompt = "hi " * (args.input_len + i)
tokenized_prompt = tokenizer(prompt).input_ids
if len(tokenized_prompt) == args.input_len:
break
else:
raise ValueError(
f"Failed to synthesize a prompt with {args.input_len} tokens.")
requests = [(prompt, args.input_len, args.output_len)
for _ in range(args.num_prompts)]
else:
......@@ -369,35 +291,16 @@ def main(args: argparse.Namespace):
if args.backend == "vllm":
if args.async_engine:
run_args = [
requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.download_dir, args.load_format, args.disable_async_output_proc
]
else:
run_args = [
warmup_requests, requests, args.model, args.tokenizer, args.quantization,
args.tensor_parallel_size, args.seed, args.n,
args.trust_remote_code, args.dtype, args.max_model_len,
args.enforce_eager, args.kv_cache_dtype,
args.quantization_param_path, args.device,
args.enable_prefix_caching, args.enable_chunked_prefill,
args.max_num_batched_tokens, args.distributed_executor_backend,
args.gpu_memory_utilization, args.num_scheduler_steps,
args.download_dir, args.load_format, args.disable_async_output_proc
]
if args.async_engine:
run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args))
elapsed_time = uvloop.run(
run_vllm_async(
requests,
args.n,
AsyncEngineArgs.from_cli_args(args),
args.disable_frontend_multiprocessing,
))
else:
elapsed_time = run_vllm(*run_args)
elapsed_time = run_vllm(warmup_requests, requests, args.n,
EngineArgs.from_cli_args(args))
elif args.backend == "hf":
assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
......@@ -452,13 +355,6 @@ if __name__ == "__main__":
default=None,
help="Output length for each request. Overrides the "
"output length from the dataset.")
parser.add_argument("--model", type=str, default="facebook/opt-125m")
parser.add_argument("--tokenizer", type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument("--n",
type=int,
default=1,
......@@ -471,123 +367,15 @@ if __name__ == "__main__":
type=int,
default=1000,
help="Number of prompts to process.")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--hf-max-batch-size",
type=int,
default=None,
help="Maximum batch size for HF backend.")
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--max-model-len',
type=int,
default=None,
help='Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.')
parser.add_argument("--enforce-eager",
action="store_true",
help="enforce eager execution")
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument("--device",
type=str,
default="auto",
choices=DEVICE_OPTIONS,
help='device type for vLLM execution')
parser.add_argument(
"--num-scheduler-steps",
type=int,
default=1,
help="Maximum number of forward steps per scheduler call.")
parser.add_argument(
"--enable-prefix-caching",
action='store_true',
help="Enable automatic prefix caching for vLLM backend.")
parser.add_argument("--enable-chunked-prefill",
action='store_true',
help="enable chunked prefill for vLLM backend.")
parser.add_argument('--max-num-batched-tokens',
type=int,
default=None,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the throughput results in JSON format.')
parser.add_argument(
'--distributed-executor-backend',
choices=['ray', 'mp'],
default=None,
help='Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.')
parser.add_argument(
'--load-format',
type=str,
default=EngineArgs.load_format,
choices=[
'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
'bitsandbytes'
],
help='The format of the model weights to load.\n\n'
'* "auto" will try to load the weights in the safetensors format '
'and fall back to the pytorch bin format if safetensors format '
'is not available.\n'
'* "pt" will load the weights in the pytorch bin format.\n'
'* "safetensors" will load the weights in the safetensors format.\n'
'* "npcache" will load the weights in pytorch format and store '
'a numpy cache to speed up the loading.\n'
'* "dummy" will initialize the weights with random values, '
'which is mainly for profiling.\n'
'* "tensorizer" will load the weights using tensorizer from '
'CoreWeave. See the Tensorize vLLM Model script in the Examples'
'section for more information.\n'
'* "bitsandbytes" will load the weights using bitsandbytes '
'quantization.\n')
parser.add_argument(
"--disable-async-output-proc",
action='store_true',
default=False,
help="Disable async output processor for vLLM backend.")
parser.add_argument("--async-engine",
action='store_true',
default=False,
......@@ -596,6 +384,7 @@ if __name__ == "__main__":
action='store_true',
default=False,
help="Disable decoupled async engine frontend.")
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
if args.tokenizer is None:
args.tokenizer = args.model
......
__commit__ = "93ec62b8556e279d2c050bdc1c3247831bd39466"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment