Commit 6640dc0b authored by zhuwenwen's avatar zhuwenwen
Browse files
parents 44d4d334 83e4e0fe
...@@ -13,9 +13,10 @@ from vllm.pooling_params import PoolingParams ...@@ -13,9 +13,10 @@ from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.version import __dcu_version__ from vllm.version import __dcu_version__
__version__ = "0.5.0" from .version import __version__
__all__ = [ __all__ = [
"__version__",
"LLM", "LLM",
"ModelRegistry", "ModelRegistry",
"PromptStrictInputs", "PromptStrictInputs",
......
import contextlib import contextlib
import functools
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
try: try:
import vllm._C import vllm._C
except ImportError as e: except ImportError as e:
from vllm.logger import init_logger
logger = init_logger(__name__)
logger.warning("Failed to import from vllm._C with %r", e) logger.warning("Failed to import from vllm._C with %r", e)
with contextlib.suppress(ImportError): with contextlib.suppress(ImportError):
...@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool: ...@@ -23,6 +26,25 @@ def is_custom_op_supported(op_name: str) -> bool:
return op is not None return op is not None
def hint_on_error(fn):
    """Wrap ``fn`` so that an ``AttributeError`` is logged with a build hint.

    The wrapper forwards all positional and keyword arguments to ``fn`` and
    returns its result unchanged. If ``fn`` raises ``AttributeError``, an
    error is logged suggesting that an obsolete vllm build may be installed,
    and the same exception is re-raised. Any other exception propagates
    untouched.

    Args:
        fn: The callable to wrap.

    Returns:
        A wrapper carrying the metadata of ``fn`` (via ``functools.wraps``).
    """

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except AttributeError as e:
            msg = (
                "Error in calling custom op %s: %s\n"
                "Possibly you have built or installed an obsolete version of vllm.\n"
                # The trailing space keeps the two sentence halves from being
                # glued together ("vllm,or remove ...") in the log output.
                "Please try a clean build and install of vllm, "
                "or remove old built files such as vllm/*cpython*.so and build/ ."
            )
            logger.error(msg, fn.__name__, e)
            # Bare `raise` re-raises the active exception without adding an
            # extra traceback entry for this line.
            raise

    return wrapper
# activation ops # activation ops
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
torch.ops._C.silu_and_mul(out, x) torch.ops._C.silu_and_mul(out, x)
...@@ -190,8 +212,8 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -190,8 +212,8 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass # cutlass
def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, def cutlass_scaled_mm(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: Type[torch.dtype]) -> torch.Tensor: out_dtype: Type[torch.dtype]) -> torch.Tensor:
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0) assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16) assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
...@@ -200,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor, ...@@ -200,8 +222,7 @@ def cutlass_scaled_mm_dq(a: torch.Tensor, b: torch.Tensor,
n = b.shape[1] n = b.shape[1]
out = torch.empty((m, n), dtype=out_dtype, device=a.device) out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_mm_dq(out, a, b, scale_a, scale_b) torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b)
return out return out
...@@ -459,3 +480,25 @@ def dispatch_bgmv_low_level( ...@@ -459,3 +480,25 @@ def dispatch_bgmv_low_level(
h_out, h_out,
y_offset, y_offset,
) )
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
#
# Wrap every function that is defined in this file and has torch.Tensor in
# its annotations with `hint_on_error`. The `== "torch.Tensor"` comparison
# handles the case when `from __future__ import annotations` turns type
# hints into strings.
_fn_type = type(lambda x: x)
_ops_to_wrap = {
    name: hint_on_error(obj)
    # Iterate over a snapshot of the module namespace so that names bound
    # while this runs cannot change the dict size during iteration, and so
    # that no loop variables need to be pre-declared (or leak) at module level.
    for name, obj in list(globals().items())
    if isinstance(obj, _fn_type)
    and obj.__code__.co_filename == __file__
    and any(ann is torch.Tensor or ann == "torch.Tensor"
            for ann in obj.__annotations__.values())
}
globals().update(_ops_to_wrap)
# Leave no helper names behind in the module namespace.
del _fn_type, _ops_to_wrap
...@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -317,7 +317,7 @@ class FlashAttentionImpl(AttentionImpl):
# normal attention # normal attention
# When block_tables are not filled, it means q and k are the # When block_tables are not filled, it means q and k are the
# prompt, and they have the same length. # prompt, and they have the same length.
flash_attn_varlen_func( out = flash_attn_varlen_func(
q=query, q=query,
k=key, k=key,
v=value, v=value,
...@@ -329,13 +329,14 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -329,13 +329,14 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output[:num_prefill_tokens],
) )
assert output[:num_prefill_tokens].shape == out.shape
output[:num_prefill_tokens] = out
else: else:
# prefix-enabled attention # prefix-enabled attention
assert prefill_meta.seq_lens is not None assert prefill_meta.seq_lens is not None
max_seq_len = max(prefill_meta.seq_lens) max_seq_len = max(prefill_meta.seq_lens)
flash_attn_varlen_func( output[:num_prefill_tokens] = flash_attn_varlen_func(
q=query, q=query,
k=key_cache, k=key_cache,
v=value_cache, v=value_cache,
...@@ -347,12 +348,11 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -347,12 +348,11 @@ class FlashAttentionImpl(AttentionImpl):
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
block_table=prefill_meta.block_tables, block_table=prefill_meta.block_tables,
out=output[:num_prefill_tokens],
) )
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
flash_attn_with_kvcache( output[num_prefill_tokens:] = flash_attn_with_kvcache(
decode_query.unsqueeze(1), decode_query.unsqueeze(1),
key_cache, key_cache,
value_cache, value_cache,
...@@ -361,8 +361,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -361,8 +361,7 @@ class FlashAttentionImpl(AttentionImpl):
softmax_scale=self.scale, softmax_scale=self.scale,
causal=True, causal=True,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
out=output[num_prefill_tokens:].unsqueeze(1), ).squeeze(1)
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
import torch_xla.experimental.custom_kernel # Required to register custom ops.
import torch_xla.experimental.dynamo_set_buffer_donor
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata)
class PallasAttentionBackend(AttentionBackend):
    """Attention backend descriptor for the Pallas implementation."""

    @staticmethod
    def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
        """Return the implementation class used by this backend."""
        return PallasAttentionBackendImpl

    @staticmethod
    def make_metadata(*args, **kwargs) -> "PallasMetadata":
        """Build a ``PallasMetadata`` from the forwarded arguments."""
        metadata = PallasMetadata(*args, **kwargs)
        return metadata

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
        """Return the cache shape, laid out with the KV-head axis first."""
        shape = (num_kv_heads, num_blocks, block_size, head_size)
        return shape

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
    ) -> None:
        """Not supported by this backend; always raises."""
        raise NotImplementedError("swap_blocks is not implemented.")

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
    ) -> None:
        """Not supported by this backend yet; always raises."""
        # TODO(woosuk): Implement this.
        raise NotImplementedError("copy_blocks is not implemented.")
@dataclass
class PallasMetadata(AttentionMetadata):
    """Attention metadata for the Pallas backend.

    A batch is expected to be either all prefills or all decodes; the two
    properties below assert that invariant before returning ``self``.
    """

    # Currently, input sequences can only contain all prefills
    # or all decoding.
    block_tables: Optional[torch.Tensor]
    context_lens: Optional[torch.Tensor]

    @property
    def prefill_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a prefill-only batch, or ``None`` if no prefills."""
        if self.num_prefills == 0:
            return None

        # A prefill batch carries no decode tokens and no paged-cache lookups.
        assert self.num_decode_tokens == 0
        assert self.block_tables is None
        assert self.context_lens is None
        return self

    @property
    def decode_metadata(self) -> Optional["PallasMetadata"]:
        """Return ``self`` for a decode-only batch, or ``None`` if no decodes."""
        if self.num_decode_tokens == 0:
            return None

        # A decode batch carries no prefills and needs both lookup tensors.
        assert self.num_prefills == 0
        assert self.num_prefill_tokens == 0
        assert self.block_tables is not None
        assert self.context_lens is not None
        return self
class PallasAttentionBackendImpl(AttentionImpl):
    """Attention implementation that dispatches to the XLA Pallas kernels.

    Prefill batches go through ``torch.ops.xla.flash_attention`` and decode
    batches through ``torch.ops.xla.paged_attention``.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[List[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        blocksparse_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Validate the configuration and pick the megacore mode.

        Raises:
            NotImplementedError: If the head size is not a multiple of 128,
                if alibi slopes, a sliding window, a non-"auto" KV cache
                dtype or blocksparse parameters are requested, or if the TPU
                version is lower than 4.
        """
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        if head_size % 128 != 0:
            raise NotImplementedError("Head size must be a multiple of 128.")
        if alibi_slopes is not None:
            raise NotImplementedError("Alibi slopes is not supported.")
        if sliding_window is not None:
            raise NotImplementedError("Sliding window is not supported.")
        if kv_cache_dtype != "auto":
            raise NotImplementedError("FP8 KV cache dtype is not supported.")
        if blocksparse_params is not None:
            raise NotImplementedError("Blocksparse is not supported.")

        if torch_xla.tpu.version() < 4:
            raise NotImplementedError("TPU version must be 4 or higher.")

        self.megacore_mode = None
        # `get_tpu_env` is the torch_xla helper that exposes the TPU
        # environment metadata; its "TYPE" entry names the chip variant.
        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
        if not tpu_type.endswith("lite"):
            if self.num_kv_heads % 2 == 0:
                self.megacore_mode = "kv_head"
            else:
                # NOTE(woosuk): If the batch size is not a multiple of 2, the
                # megacore mode will be None.
                self.megacore_mode = "batch"

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Tuple[Optional[torch.Tensor], Optional[torch.Tensor]],
        attn_metadata: PallasMetadata,
        kv_scale: float = 1.0,
    ) -> torch.Tensor:
        """Forward pass with Pallas attention.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            kv_cache: (key_cache, value_cache) pair, where
                key_cache = [num_kv_heads, num_blocks, block_size, head_size]
                value_cache = [num_kv_heads, num_blocks, block_size, head_size]
                Both entries may be None when no cache is allocated.
            attn_metadata: Metadata for attention.
            kv_scale: Must be 1.0; other values are not supported.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]
        """
        assert kv_scale == 1.0
        batch_size, seq_len, hidden_size = query.shape
        query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
        key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
        value = value.view(batch_size, seq_len, self.num_kv_heads,
                           self.head_size)

        if kv_cache[0] is not None:
            slot_mapping = attn_metadata.slot_mapping
            key_cache, value_cache = kv_cache
            write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)

        query = query * self.scale
        if attn_metadata.num_prefills > 0:
            assert seq_len % 16 == 0, (
                "Pallas FlashAttention kernel requires seq_len to be a "
                f"multiple of 16 but got {seq_len}")

            # Handle GQA/MQA.
            if self.num_kv_heads != self.num_heads:
                key = key.repeat_interleave(self.num_queries_per_kv, dim=-2)
                key = key.view(batch_size, seq_len, self.num_heads,
                               self.head_size)
                value = value.repeat_interleave(self.num_queries_per_kv,
                                                dim=-2)
                value = value.view(batch_size, seq_len, self.num_heads,
                                   self.head_size)
            # FlashAttention requires [batch_size, num_heads, seq_len, d_model]
            # while the input is [batch_size, seq_len, num_heads, d_model].
            # Permute the input to match the required format.
            output = torch.ops.xla.flash_attention(
                query.permute(0, 2, 1, 3),
                key.permute(0, 2, 1, 3),
                value.permute(0, 2, 1, 3),
                True,
            )
            output = output.permute(0, 2, 1, 3)
        else:
            # Decoding run.
            # `kv_cache` itself is always a tuple, so check its first entry:
            # `key_cache` / `value_cache` are only bound above when the cache
            # has actually been allocated.
            assert kv_cache[0] is not None, (
                "Decoding requires an allocated KV cache.")

            pages_per_compute_block = 16  # TODO(woosuk): Tune this value.
            if self.megacore_mode == "batch" and batch_size % 2 != 0:
                megacore_mode = None
            else:
                megacore_mode = self.megacore_mode

            # NOTE(woosuk): A temporary workaround to avoid the error:
            # "xla::paged_attention() Expected a value of type 'str' for
            # argument 'megacore_mode' but instead found type 'NoneType'."
            if megacore_mode is not None:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                    megacore_mode=megacore_mode,
                )
            else:
                output = torch.ops.xla.paged_attention(
                    query.squeeze(dim=1),
                    key_cache,
                    value_cache,
                    attn_metadata.context_lens,
                    attn_metadata.block_tables,
                    pages_per_compute_block,
                )

        # Reshape the output tensor.
        return output.reshape(batch_size, seq_len, hidden_size)
def write_to_kv_cache(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
) -> None:
    """Copy ``key`` / ``value`` rows into the caches at ``slot_mapping``.

    All four tensors have their first three dimensions flattened into one,
    after which the key/value rows are written into the flattened caches with
    ``index_copy_`` along dimension 0.
    """
    # Flag both caches through the XLA `dynamo_set_buffer_donor_` op before
    # touching them (presumably so XLA may reuse their buffers for the
    # update -- TODO confirm against the torch_xla documentation).
    for cache in (key_cache, value_cache):
        torch.ops.xla.dynamo_set_buffer_donor_(cache, True)

    flat_key = key.flatten(0, 2)
    flat_value = value.flatten(0, 2)
    flat_key_cache = key_cache.flatten(0, 2)
    flat_value_cache = value_cache.flatten(0, 2)

    flat_key_cache.index_copy_(0, slot_mapping, flat_key)
    flat_value_cache.index_copy_(0, slot_mapping, flat_value)
...@@ -8,8 +8,16 @@ from torch.nn.functional import scaled_dot_product_attention ...@@ -8,8 +8,16 @@ from torch.nn.functional import scaled_dot_product_attention
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata) AttentionMetadata)
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import PagedAttentionMetadata
PagedAttentionMetadata) from vllm.utils import is_cpu
if is_cpu():
try:
from vllm.attention.ops.ipex_attn import PagedAttention
except ImportError:
from vllm.attention.ops.paged_attn import PagedAttention
else:
from vllm.attention.ops.paged_attn import PagedAttention
class TorchSDPABackend(AttentionBackend): class TorchSDPABackend(AttentionBackend):
...@@ -197,13 +205,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -197,13 +205,14 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
attn_metadata.attn_bias): attn_metadata.attn_bias):
end = start + seq_len end = start + seq_len
sub_out = scaled_dot_product_attention( sub_out = scaled_dot_product_attention(
query[:, start:end, :], query[None, :, start:end, :],
key[:, start:end, :], key[None, :, start:end, :],
value[:, start:end, :], value[None, :, start:end, :],
attn_mask=mask, attn_mask=mask,
dropout_p=0.0, dropout_p=0.0,
is_causal=not self.need_mask, is_causal=not self.need_mask,
scale=self.scale).movedim(query.dim() - 2, 0) scale=self.scale).squeeze(0).movedim(
query.dim() - 2, 0)
output[start:end, :, :] = sub_out output[start:end, :, :] = sub_out
start = end start = end
else: else:
...@@ -248,7 +257,7 @@ def _make_alibi_bias( ...@@ -248,7 +257,7 @@ def _make_alibi_bias(
num_heads = alibi_slopes.shape[0] num_heads = alibi_slopes.shape[0]
bias = bias[None, :].repeat((num_heads, 1, 1)) bias = bias[None, :].repeat((num_heads, 1, 1))
bias.mul_(alibi_slopes[:, None, None]) bias.mul_(alibi_slopes[:, None, None]).unsqueeze_(0)
inf_mask = torch.empty( inf_mask = torch.empty(
(1, seq_len, seq_len), (1, seq_len, seq_len),
dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1) dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1)
......
from typing import Dict, List, Optional, Tuple
import intel_extension_for_pytorch.llm.modules as ipex_modules
import torch
from vllm import _custom_ops as ops
class PagedAttention:
    """Paged-attention helpers backed by Intel Extension for PyTorch (IPEX).

    Every method accepts and ignores trailing positional ``*args``.
    """

    @staticmethod
    def get_supported_head_sizes() -> List[int]:
        """Return the list of supported head sizes."""
        supported = [64, 80, 96, 112, 128, 256]
        return supported

    @staticmethod
    def get_kv_cache_shape(
        num_blocks: int,
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[int, ...]:
        """Return the shape of the combined cache: (2, blocks, block elems)."""
        elems_per_block = block_size * num_kv_heads * head_size
        return (2, num_blocks, elems_per_block)

    @staticmethod
    def split_kv_cache(
        kv_cache: torch.Tensor,
        num_kv_heads: int,
        head_size: int,
        *args,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Split the combined cache into 4-D key and value cache views.

        Index 0 of the first dimension is the key cache and index 1 the value
        cache; each is viewed as (num_blocks, num_kv_heads, -1, head_size).
        """
        num_blocks = kv_cache.shape[1]
        view_shape = (num_blocks, num_kv_heads, -1, head_size)

        key_cache = kv_cache[0].view(*view_shape)
        value_cache = kv_cache[1].view(*view_shape)
        return key_cache, value_cache

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        kv_scale: float,
        *args,
    ) -> None:
        """Store ``key`` / ``value`` in the caches via the IPEX kernel.

        ``kv_cache_dtype`` and ``kv_scale`` are accepted but not used here.
        """
        # The IPEX kernel is handed a flattened int32 slot mapping.
        flat_slots = slot_mapping.flatten().int()
        ipex_modules.PagedAttention.reshape_and_cache(
            key, value, key_cache, value_cache, flat_slots)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        context_lens: torch.Tensor,
        max_context_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        kv_scale: float,
        *args,
    ) -> torch.Tensor:
        """Run single-query cached-KV attention and return a new tensor.

        The result has the same shape and dtype as ``query``.
        ``kv_cache_dtype`` and ``kv_scale`` are accepted but not used here.
        """
        output = torch.empty_like(query)
        block_size = value_cache.shape[2]

        # Map each query head to its KV head: every KV head index is repeated
        # once per query head that shares it.
        kv_head_ids = torch.arange(
            0,
            num_kv_heads,
            device="cpu",
            dtype=torch.int32,
        ).view(num_kv_heads, 1)
        queries_per_kv = query.size(1) // num_kv_heads
        head_mapping = kv_head_ids.repeat_interleave(queries_per_kv).flatten()

        ipex_modules.PagedAttention.single_query_cached_kv_attention(
            output, query.contiguous(), key_cache, value_cache, head_mapping,
            scale, block_tables, context_lens, block_size, max_context_len,
            alibi_slopes)

        return output

    @staticmethod
    def forward_prefix(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        subquery_start_loc: torch.Tensor,
        prompt_lens_tensor: torch.Tensor,
        context_lens: torch.Tensor,
        max_subquery_len: int,
        alibi_slopes: Optional[torch.Tensor],
        *args,
    ) -> torch.Tensor:
        """Not supported by this implementation; always raises."""
        raise NotImplementedError

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: Dict[int, int],
        *args,
    ) -> None:
        """Not supported by this implementation; always raises."""
        raise NotImplementedError

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: Dict[int, List[int]],
        *args,
    ) -> None:
        """Copy blocks inside each cache according to ``src_to_dists``."""
        key_caches = []
        value_caches = []
        for kv_cache in kv_caches:
            # Index 0 holds the keys and index 1 the values of each cache.
            key_caches.append(kv_cache[0])
            value_caches.append(kv_cache[1])
        ops.copy_blocks(key_caches, value_caches, src_to_dists)
...@@ -213,7 +213,7 @@ def _attn_fwd_inner( ...@@ -213,7 +213,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 256, "BLOCK_M": 256,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -223,7 +223,7 @@ def _attn_fwd_inner( ...@@ -223,7 +223,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 128, "BLOCK_N": 128,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -233,7 +233,7 @@ def _attn_fwd_inner( ...@@ -233,7 +233,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 256, "BLOCK_M": 256,
"BLOCK_N": 128, "BLOCK_N": 128,
"waves_per_eu": 2, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -243,7 +243,7 @@ def _attn_fwd_inner( ...@@ -243,7 +243,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 1, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -253,7 +253,7 @@ def _attn_fwd_inner( ...@@ -253,7 +253,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 3, "waves_per_eu": 0,
"PRE_LOAD_V": True, "PRE_LOAD_V": True,
}, },
num_stages=1, num_stages=1,
...@@ -263,7 +263,7 @@ def _attn_fwd_inner( ...@@ -263,7 +263,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 128, "BLOCK_M": 128,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 3, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -273,7 +273,7 @@ def _attn_fwd_inner( ...@@ -273,7 +273,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 64, "BLOCK_M": 64,
"BLOCK_N": 64, "BLOCK_N": 64,
"waves_per_eu": 4, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -283,7 +283,7 @@ def _attn_fwd_inner( ...@@ -283,7 +283,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 32, "BLOCK_M": 32,
"BLOCK_N": 32, "BLOCK_N": 32,
"waves_per_eu": 4, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
...@@ -296,7 +296,7 @@ def _attn_fwd_inner( ...@@ -296,7 +296,7 @@ def _attn_fwd_inner(
{ {
"BLOCK_M": 16, "BLOCK_M": 16,
"BLOCK_N": 16, "BLOCK_N": 16,
"waves_per_eu": 1, "waves_per_eu": 0,
"PRE_LOAD_V": False, "PRE_LOAD_V": False,
}, },
num_stages=1, num_stages=1,
......
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import is_cpu, is_hip from vllm.utils import is_cpu, is_hip, is_tpu
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -18,6 +18,7 @@ class _Backend(enum.Enum): ...@@ -18,6 +18,7 @@ class _Backend(enum.Enum):
ROCM_FLASH = enum.auto() ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
PALLAS = enum.auto()
@lru_cache(maxsize=None) @lru_cache(maxsize=None)
...@@ -57,6 +58,9 @@ def get_attn_backend( ...@@ -57,6 +58,9 @@ def get_attn_backend(
ROCmFlashAttentionBackend) ROCmFlashAttentionBackend)
return ROCmFlashAttentionBackend return ROCmFlashAttentionBackend
elif backend == _Backend.TORCH_SDPA: elif backend == _Backend.TORCH_SDPA:
# TODO: make XPU backend available here.
assert is_cpu(), RuntimeError(
"Torch SDPA backend is only used for the CPU device.")
logger.info("Using Torch SDPA backend.") logger.info("Using Torch SDPA backend.")
from vllm.attention.backends.torch_sdpa import TorchSDPABackend from vllm.attention.backends.torch_sdpa import TorchSDPABackend
return TorchSDPABackend return TorchSDPABackend
...@@ -66,6 +70,10 @@ def get_attn_backend( ...@@ -66,6 +70,10 @@ def get_attn_backend(
"Please make sure --enforce-eager is set.") "Please make sure --enforce-eager is set.")
from vllm.attention.backends.flashinfer import FlashInferBackend from vllm.attention.backends.flashinfer import FlashInferBackend
return FlashInferBackend return FlashInferBackend
elif backend == _Backend.PALLAS:
logger.info("Using Pallas backend.")
from vllm.attention.backends.pallas import PallasAttentionBackend
return PallasAttentionBackend
else: else:
raise ValueError("Invalid attention backend.") raise ValueError("Invalid attention backend.")
...@@ -80,7 +88,6 @@ def which_attn_to_use( ...@@ -80,7 +88,6 @@ def which_attn_to_use(
block_size: int, block_size: int,
) -> _Backend: ) -> _Backend:
"""Returns which flash attention backend to use.""" """Returns which flash attention backend to use."""
# Default case. # Default case.
selected_backend = _Backend.FLASH_ATTN selected_backend = _Backend.FLASH_ATTN
...@@ -100,6 +107,11 @@ def which_attn_to_use( ...@@ -100,6 +107,11 @@ def which_attn_to_use(
logger.info("Cannot use %s backend on CPU.", selected_backend) logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA return _Backend.TORCH_SDPA
if is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS
if is_hip(): if is_hip():
# AMD GPUs. # AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
......
...@@ -11,7 +11,8 @@ from vllm.logger import init_logger ...@@ -11,7 +11,8 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.models import ModelRegistry from vllm.model_executor.models import ModelRegistry
from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import get_cpu_memory, is_cpu, is_hip, is_neuron from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_tpu)
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.util.placement_group import PlacementGroup from ray.util.placement_group import PlacementGroup
...@@ -212,7 +213,7 @@ class ModelConfig: ...@@ -212,7 +213,7 @@ class ModelConfig:
f"{self.quantization} quantization is currently not " f"{self.quantization} quantization is currently not "
f"supported in ROCm.") f"supported in ROCm.")
if (self.quantization if (self.quantization
not in ["marlin", "gptq_marlin_24", "gptq_marlin"]): not in ("fp8", "marlin", "gptq_marlin_24", "gptq_marlin")):
logger.warning( logger.warning(
"%s quantization is not fully " "%s quantization is not fully "
"optimized yet. The speed can be slower than " "optimized yet. The speed can be slower than "
...@@ -605,12 +606,11 @@ class ParallelConfig: ...@@ -605,12 +606,11 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size > 1: if self.distributed_executor_backend is None and self.world_size > 1:
# We use multiprocessing by default if world_size fits on the # We use multiprocessing by default if world_size fits on the
# current node and we aren't in a ray placement group. # current node and we aren't in a ray placement group.
from torch.cuda import device_count
from vllm.executor import ray_utils from vllm.executor import ray_utils
backend = "mp" backend = "mp"
ray_found = ray_utils.ray is not None ray_found = ray_utils.ray is not None
if device_count() < self.world_size: if cuda_device_count_stateless() < self.world_size:
if not ray_found: if not ray_found:
raise ValueError("Unable to load Ray which is " raise ValueError("Unable to load Ray which is "
"required for multi-node inference") "required for multi-node inference")
...@@ -748,6 +748,8 @@ class DeviceConfig: ...@@ -748,6 +748,8 @@ class DeviceConfig:
# Automated device type detection # Automated device type detection
if is_neuron(): if is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_tpu():
self.device_type = "tpu"
elif is_cpu(): elif is_cpu():
self.device_type = "cpu" self.device_type = "cpu"
else: else:
...@@ -761,6 +763,8 @@ class DeviceConfig: ...@@ -761,6 +763,8 @@ class DeviceConfig:
# Some device types require processing inputs on CPU # Some device types require processing inputs on CPU
if self.device_type in ["neuron"]: if self.device_type in ["neuron"]:
self.device = torch.device("cpu") self.device = torch.device("cpu")
elif self.device_type in ["tpu"]:
self.device = None
else: else:
# Set device with device type # Set device with device type
self.device = torch.device(self.device_type) self.device = torch.device(self.device_type)
......
...@@ -50,8 +50,8 @@ class SchedulingBudget: ...@@ -50,8 +50,8 @@ class SchedulingBudget:
""" """
token_budget: int token_budget: int
max_num_seqs: int max_num_seqs: int
_requeset_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_batched_tokens: Set[str] = field(default_factory=set)
_requeset_ids_num_curr_seqs: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set)
_num_batched_tokens: int = 0 _num_batched_tokens: int = 0
_num_curr_seqs: int = 0 _num_curr_seqs: int = 0
...@@ -65,28 +65,28 @@ class SchedulingBudget: ...@@ -65,28 +65,28 @@ class SchedulingBudget:
return self.token_budget - self.num_batched_tokens return self.token_budget - self.num_batched_tokens
def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens: if req_id in self._request_ids_num_batched_tokens:
return return
self._requeset_ids_num_batched_tokens.add(req_id) self._request_ids_num_batched_tokens.add(req_id)
self._num_batched_tokens += num_batched_tokens self._num_batched_tokens += num_batched_tokens
def subtract_num_batched_tokens(self, req_id: str, def subtract_num_batched_tokens(self, req_id: str,
num_batched_tokens: int): num_batched_tokens: int):
if req_id in self._requeset_ids_num_batched_tokens: if req_id in self._request_ids_num_batched_tokens:
self._requeset_ids_num_batched_tokens.remove(req_id) self._request_ids_num_batched_tokens.remove(req_id)
self._num_batched_tokens -= num_batched_tokens self._num_batched_tokens -= num_batched_tokens
def add_num_seqs(self, req_id: str, num_curr_seqs: int): def add_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs: if req_id in self._request_ids_num_curr_seqs:
return return
self._requeset_ids_num_curr_seqs.add(req_id) self._request_ids_num_curr_seqs.add(req_id)
self._num_curr_seqs += num_curr_seqs self._num_curr_seqs += num_curr_seqs
def subtract_num_seqs(self, req_id: str, num_curr_seqs: int): def subtract_num_seqs(self, req_id: str, num_curr_seqs: int):
if req_id in self._requeset_ids_num_curr_seqs: if req_id in self._request_ids_num_curr_seqs:
self._requeset_ids_num_curr_seqs.remove(req_id) self._request_ids_num_curr_seqs.remove(req_id)
self._num_curr_seqs -= num_curr_seqs self._num_curr_seqs -= num_curr_seqs
@property @property
......
from collections import namedtuple from typing import Any, Dict, Optional, Union
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
from torch.distributed import ProcessGroup import torch.distributed
from .parallel_state import (get_cpu_world_group, get_pp_pynccl_communicator, from .parallel_state import get_tp_group
get_tensor_model_parallel_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_ca_communicator,
get_tp_pynccl_communicator)
@dataclass
class GraphCaptureContext:
    """Data handed to code running inside a ``graph_capture`` block."""

    # The CUDA stream on which the graph capture runs.
    stream: torch.cuda.Stream
@contextmanager
def graph_capture():
    """Context manager that should surround code capturing a CUDA graph.

    It creates a fresh CUDA stream, makes it the current stream for the
    duration of the block, and yields a `GraphCaptureContext` holding that
    stream. Capturing on a dedicated stream explicitly separates the kernels
    to capture from kernels possibly launched in the background on the
    default stream. While the block is active, the custom all-reduce
    communicator (if any) is put into capture mode and the TP / PP pynccl
    communicators (if any) are enabled on the capture stream.
    """
    capture_stream = torch.cuda.Stream()
    capture_context = GraphCaptureContext(capture_stream)

    ca_comm = get_tp_ca_communicator()
    ca_capture = ca_comm.capture() if ca_comm is not None else nullcontext()

    with torch.cuda.stream(capture_stream), ca_capture:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Custom all-reduce has a runtime check: if the tensor is too large
        # it falls back to the next available option. So with CUDA graphs we
        # use either the custom all-reduce kernel or pynccl, and without
        # them either the custom all-reduce kernel or PyTorch NCCL. The
        # custom kernel is always preferred when enabled and supported.
        tp_comm = get_tp_pynccl_communicator()
        pp_comm = get_pp_pynccl_communicator()

        if tp_comm:
            tp_state = tp_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        else:
            tp_state = nullcontext()

        if pp_comm:
            pp_state = pp_comm.change_state(
                enable=True, stream=torch.cuda.current_stream())
        else:
            pp_state = nullcontext()

        with tp_state, pp_state:
            yield capture_context
def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
"""All-reduce the input tensor across model parallel group. """All-reduce the input tensor across model parallel group."""
return get_tp_group().all_reduce(input_)
NOTE: This operation will be applied in-place on the input tensor if
disable_custom_all_reduce is set to True. Otherwise, this operation may or
may not be applied in place depending on whether custom all reduce is
invoked for a particular tensor, which further depends on the tensor size
and GPU topology.
TLDR: always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm = get_tp_ca_communicator()
# Bypass the function if we are using only 1 GPU.
if get_tensor_model_parallel_world_size() == 1:
return input_
if ca_comm is not None:
out = ca_comm.custom_all_reduce(input_)
if out is not None:
return out
pynccl_comm = get_tp_pynccl_communicator()
if (pynccl_comm is not None and not pynccl_comm.disabled):
pynccl_comm.all_reduce(input_)
else:
torch.distributed.all_reduce(input_,
group=get_tensor_model_parallel_group())
return input_
def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_all_gather(input_: torch.Tensor,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""All-gather the input tensor across model parallel group.""" """All-gather the input tensor across model parallel group."""
world_size = get_tensor_model_parallel_world_size() return get_tp_group().all_gather(input_, dim)
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
input_size = input_.size()
# Allocate output tensor.
output_tensor = torch.empty((world_size, ) + input_size,
dtype=input_.dtype,
device=input_.device)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=get_tensor_model_parallel_group())
# Reshape
output_tensor = output_tensor.movedim(0, dim)
output_tensor = output_tensor.reshape(input_size[:dim] +
(world_size * input_size[dim], ) +
input_size[dim + 1:])
return output_tensor
def tensor_model_parallel_gather(input_: torch.Tensor, def tensor_model_parallel_gather(input_: torch.Tensor,
dst: int = 0, dst: int = 0,
dim: int = -1) -> torch.Tensor: dim: int = -1) -> torch.Tensor:
"""Gather the input tensor across model parallel group. """Gather the input tensor across model parallel group."""
return get_tp_group().gather(input_, dst, dim)
NOTE: We assume that the input tensor is on the same device across
all the ranks.
"""
world_size = get_tensor_model_parallel_world_size()
# Bypass the function if we are using only 1 GPU.
if world_size == 1:
return input_
assert -input_.dim() <= dim < input_.dim(), (
f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
if dim < 0:
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
if get_tensor_model_parallel_rank() == dst:
gather_list = [torch.empty_like(input_) for _ in range(world_size)]
else:
gather_list = None
# Gather.
torch.distributed.gather(input_,
gather_list,
dst=dst,
group=get_tensor_model_parallel_group())
if get_tensor_model_parallel_rank() == dst:
output_tensor = torch.cat(gather_list, dim=dim)
else:
output_tensor = None
return output_tensor
def broadcast(input_: torch.Tensor,
              src: int = 0,
              group: Optional[ProcessGroup] = None):
    """Broadcast ``input_`` from rank ``src`` across ``group``.

    Args:
        input_: Tensor handed to ``torch.distributed.broadcast``.
        src: Source rank; it must be one of the ranks of ``group``.
        group: Process group to use. A falsy value (the default ``None``)
            selects ``torch.distributed.group.WORLD``.

    Returns:
        The same ``input_`` object that was passed in.

    Raises:
        AssertionError: If ``src`` is not a member of ``group``.
    """
    if not group:
        group = torch.distributed.group.WORLD
    members = torch.distributed.get_process_group_ranks(group)
    assert src in members, f"Invalid src rank ({src})"
    # With a single rank in the group there is nothing to exchange, so the
    # collective call is skipped and the input is returned as-is.
    if torch.distributed.get_world_size(group=group) != 1:
        torch.distributed.broadcast(input_, src=src, group=group)
    return input_
def broadcast_object_list(obj_list: List[Any],
                          src: int = 0,
                          group: Optional[ProcessGroup] = None):
    """Broadcast a list of Python objects from rank ``src`` across ``group``.

    Args:
        obj_list: List handed to ``torch.distributed.broadcast_object_list``.
        src: Source rank; it must be one of the ranks of ``group``.
        group: Process group to use. A falsy value (the default ``None``)
            selects ``torch.distributed.group.WORLD``.

    Returns:
        The same ``obj_list`` object that was passed in.

    Raises:
        AssertionError: If ``src`` is not a member of ``group``.
    """
    if not group:
        group = torch.distributed.group.WORLD
    members = torch.distributed.get_process_group_ranks(group)
    assert src in members, f"Invalid src rank ({src})"
    # With a single rank in the group there is nothing to exchange, so the
    # collective call is skipped and the list is returned as-is.
    if torch.distributed.get_world_size(group=group) != 1:
        torch.distributed.broadcast_object_list(obj_list,
                                                src=src,
                                                group=group)
    return obj_list
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
def _split_tensor_dict(
tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
"""Split the tensor dictionary into two parts:
1. A list of (key, value) pairs. If the value is a tensor, it is replaced
by its metadata.
2. A list of tensors.
"""
metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
if isinstance(value, torch.Tensor):
# Note: we cannot use `value.device` here,
# because it contains not only the device type but also the device
# index (e.g. "cuda:0"). We only need the device type.
# receiving side will set the device index.
device = "cpu" if value.is_cpu else "cuda"
metadata_list.append(
(key, TensorMetadata(device, value.dtype, value.size())))
tensor_list.append(value)
else:
metadata_list.append((key, value))
return metadata_list, tensor_list
def broadcast_tensor_dict(
tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
src: int = 0,
group: Optional[ProcessGroup] = None,
metadata_group: Optional[ProcessGroup] = None
) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
"""Broadcast the input tensor dictionary.
`group` is used to broadcast the tensors, while `metadata_group` is used
to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
dtypes).
"""
# Bypass the function if we are using only 1 GPU.
if (not torch.distributed.is_initialized()
or torch.distributed.get_world_size(group=group) == 1):
return tensor_dict
group = group or torch.distributed.group.WORLD
metadata_group = metadata_group or get_cpu_world_group()
ranks = torch.distributed.get_process_group_ranks(group)
assert src in ranks, f"Invalid src rank ({src})"
rank = torch.distributed.get_rank()
if rank == src:
metadata_list: List[Tuple[Any, Any]] = []
assert isinstance(
tensor_dict,
dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
# `metadata_list` lives in CPU memory.
# `broadcast_object_list` involves serialization and deserialization,
# all happening on CPU. Therefore, we can use the CPU group.
torch.distributed.broadcast_object_list([metadata_list],
src=src,
group=metadata_group)
async_handles = []
for tensor in tensor_list:
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
for async_handle in async_handles:
async_handle.wait()
else: def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor,
recv_metadata_list = [None] Any]]] = None,
torch.distributed.broadcast_object_list(recv_metadata_list, src: int = 0):
src=src, if not torch.distributed.is_initialized():
group=metadata_group)
assert recv_metadata_list[0] is not None
tensor_dict = {}
async_handles = []
for key, value in recv_metadata_list[0]:
if isinstance(value, TensorMetadata):
tensor = torch.empty(value.size,
dtype=value.dtype,
device=value.device)
if tensor.numel() == 0:
# Skip broadcasting empty tensors.
tensor_dict[key] = tensor
continue
if tensor.is_cpu:
# use metadata_group for CPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=metadata_group,
async_op=True)
else:
# use group for GPU tensors
handle = torch.distributed.broadcast(tensor,
src=src,
group=group,
async_op=True)
async_handles.append(handle)
tensor_dict[key] = tensor
else:
tensor_dict[key] = value
for async_handle in async_handles:
async_handle.wait()
return tensor_dict return tensor_dict
return get_tp_group().broadcast_tensor_dict(tensor_dict, src)
...@@ -9,9 +9,9 @@ import vllm.envs as envs ...@@ -9,9 +9,9 @@ import vllm.envs as envs
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.distributed.device_communicators.custom_all_reduce_utils import ( from vllm.distributed.device_communicators.custom_all_reduce_utils import (
gpu_p2p_access_check) gpu_p2p_access_check)
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import is_in_the_same_node
get_local_rank, get_tensor_model_parallel_cpu_group, is_in_the_same_node)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
try: try:
import pynvml import pynvml
...@@ -86,8 +86,8 @@ class CustomAllreduce: ...@@ -86,8 +86,8 @@ class CustomAllreduce:
# max_size: max supported allreduce size # max_size: max supported allreduce size
def __init__(self, def __init__(self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
max_size=8192 * 1024) -> None: max_size=8192 * 1024) -> None:
""" """
Args: Args:
...@@ -107,7 +107,6 @@ class CustomAllreduce: ...@@ -107,7 +107,6 @@ class CustomAllreduce:
# e.g. in a non-cuda environment # e.g. in a non-cuda environment
return return
group = group or get_tensor_model_parallel_cpu_group()
self.group = group self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
...@@ -134,10 +133,7 @@ class CustomAllreduce: ...@@ -134,10 +133,7 @@ class CustomAllreduce:
world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES))
return return
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
...@@ -149,7 +145,7 @@ class CustomAllreduce: ...@@ -149,7 +145,7 @@ class CustomAllreduce:
if cuda_visible_devices: if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(","))) device_ids = list(map(int, cuda_visible_devices.split(",")))
else: else:
device_ids = list(range(torch.cuda.device_count())) device_ids = list(range(cuda_device_count_stateless()))
physical_device_id = device_ids[device.index] physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id], tensor = torch.tensor([physical_device_id],
......
...@@ -11,8 +11,8 @@ import torch.distributed as dist ...@@ -11,8 +11,8 @@ import torch.distributed as dist
import torch.multiprocessing as mp import torch.multiprocessing as mp
import vllm.envs as envs import vllm.envs as envs
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import cuda_device_count_stateless
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -153,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -153,7 +153,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
is_distributed = dist.is_initialized() is_distributed = dist.is_initialized()
num_dev = torch.cuda.device_count() num_dev = cuda_device_count_stateless()
cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES
if cuda_visible_devices is None: if cuda_visible_devices is None:
cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) cuda_visible_devices = ",".join(str(i) for i in range(num_dev))
...@@ -162,7 +162,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -162,7 +162,8 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json" f"{VLLM_CONFIG_ROOT}/vllm/gpu_p2p_access_cache_for_{cuda_visible_devices}.json"
) )
os.makedirs(os.path.dirname(path), exist_ok=True) os.makedirs(os.path.dirname(path), exist_ok=True)
if ((not is_distributed or get_local_rank() == 0) from vllm.distributed.parallel_state import get_world_group
if ((not is_distributed or get_world_group().local_rank == 0)
and (not os.path.exists(path))): and (not os.path.exists(path))):
# only the local master process (with local_rank == 0) can # only the local master process (with local_rank == 0) can
# enter this block to calculate the cache # enter this block to calculate the cache
...@@ -174,8 +175,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool: ...@@ -174,8 +175,7 @@ def gpu_p2p_access_check(i: int, j: int) -> bool:
with open(path, "w") as f: with open(path, "w") as f:
json.dump(cache, f, indent=4) json.dump(cache, f, indent=4)
if is_distributed: if is_distributed:
cpu_world_group = get_cpu_world_group() get_world_group().barrier()
dist.barrier(cpu_world_group)
logger.info("reading GPU P2P access cache from %s", path) logger.info("reading GPU P2P access cache from %s", path)
with open(path, "r") as f: with open(path, "r") as f:
cache = json.load(f) cache = json.load(f)
......
...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp ...@@ -9,7 +9,6 @@ from torch.distributed import ProcessGroup, ReduceOp
from vllm.distributed.device_communicators.pynccl_wrapper import ( from vllm.distributed.device_communicators.pynccl_wrapper import (
NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum, NCCLLibrary, buffer_type, cudaStream_t, ncclComm_t, ncclDataTypeEnum,
ncclRedOpTypeEnum, ncclUniqueId) ncclRedOpTypeEnum, ncclUniqueId)
from vllm.distributed.parallel_state import get_cpu_world_group, get_local_rank
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -19,8 +18,8 @@ class PyNcclCommunicator: ...@@ -19,8 +18,8 @@ class PyNcclCommunicator:
def __init__( def __init__(
self, self,
group: Optional[ProcessGroup] = None, group: ProcessGroup,
device: Optional[Union[int, str, torch.device]] = None, device: Union[int, str, torch.device],
library_path: Optional[str] = None, library_path: Optional[str] = None,
): ):
""" """
...@@ -35,7 +34,6 @@ class PyNcclCommunicator: ...@@ -35,7 +34,6 @@ class PyNcclCommunicator:
is bind to a unique device. is bind to a unique device.
""" """
assert dist.is_initialized() assert dist.is_initialized()
group = get_cpu_world_group() if group is None else group
assert dist.get_backend(group) != dist.Backend.NCCL, ( assert dist.get_backend(group) != dist.Backend.NCCL, (
"PyNcclCommunicator should be attached to a non-NCCL group.") "PyNcclCommunicator should be attached to a non-NCCL group.")
self.group = group self.group = group
...@@ -77,10 +75,7 @@ class PyNcclCommunicator: ...@@ -77,10 +75,7 @@ class PyNcclCommunicator:
byte_list = tensor.tolist() byte_list = tensor.tolist()
for i, byte in enumerate(byte_list): for i, byte in enumerate(byte_list):
self.unique_id.internal[i] = byte self.unique_id.internal[i] = byte
if device is None: if isinstance(device, int):
local_rank = get_local_rank()
device = torch.device(f"cuda:{local_rank}")
elif isinstance(device, int):
device = torch.device(f"cuda:{device}") device = torch.device(f"cuda:{device}")
elif isinstance(device, str): elif isinstance(device, str):
device = torch.device(device) device = torch.device(device)
......
This diff is collapsed.
...@@ -504,7 +504,7 @@ class EngineArgs: ...@@ -504,7 +504,7 @@ class EngineArgs:
parser.add_argument("--device", parser.add_argument("--device",
type=str, type=str,
default=EngineArgs.device, default=EngineArgs.device,
choices=["auto", "cuda", "neuron", "cpu"], choices=["auto", "cuda", "neuron", "cpu", "tpu"],
help='Device type for vLLM execution.') help='Device type for vLLM execution.')
# Related to Vision-language models such as llava # Related to Vision-language models such as llava
......
...@@ -375,6 +375,9 @@ class AsyncLLMEngine: ...@@ -375,6 +375,9 @@ class AsyncLLMEngine:
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutorAsync
executor_class = TPUExecutorAsync
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
assert distributed_executor_backend is None, ( assert distributed_executor_backend is None, (
"Distributed execution is not supported with the CPU backend.") "Distributed execution is not supported with the CPU backend.")
......
...@@ -6,7 +6,6 @@ from typing import Type, TypeVar, Union ...@@ -6,7 +6,6 @@ from typing import Type, TypeVar, Union
from transformers import GenerationConfig, PreTrainedTokenizer from transformers import GenerationConfig, PreTrainedTokenizer
import vllm
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ParallelConfig, LoRAConfig, ModelConfig, ParallelConfig,
SchedulerConfig, SpeculativeConfig, SchedulerConfig, SpeculativeConfig,
...@@ -38,6 +37,7 @@ from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, ...@@ -38,6 +37,7 @@ from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled, from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message) usage_message)
from vllm.utils import Counter from vllm.utils import Counter
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__) logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5 _LOCAL_LOGGING_INTERVAL_SEC = 5
...@@ -169,7 +169,7 @@ class LLMEngine: ...@@ -169,7 +169,7 @@ class LLMEngine:
"enforce_eager=%s, kv_cache_dtype=%s, " "enforce_eager=%s, kv_cache_dtype=%s, "
"quantization_param_path=%s, device_config=%s, " "quantization_param_path=%s, device_config=%s, "
"decoding_config=%r, seed=%d, served_model_name=%s)", "decoding_config=%r, seed=%d, served_model_name=%s)",
vllm.__version__, VLLM_VERSION,
model_config.model, model_config.model,
speculative_config, speculative_config,
model_config.tokenizer, model_config.tokenizer,
...@@ -341,6 +341,9 @@ class LLMEngine: ...@@ -341,6 +341,9 @@ class LLMEngine:
if engine_config.device_config.device_type == "neuron": if engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutor from vllm.executor.neuron_executor import NeuronExecutor
executor_class = NeuronExecutor executor_class = NeuronExecutor
elif engine_config.device_config.device_type == "tpu":
from vllm.executor.tpu_executor import TPUExecutor
executor_class = TPUExecutor
elif engine_config.device_config.device_type == "cpu": elif engine_config.device_config.device_type == "cpu":
from vllm.executor.cpu_executor import CPUExecutor from vllm.executor.cpu_executor import CPUExecutor
executor_class = CPUExecutor executor_class = CPUExecutor
......
...@@ -545,11 +545,13 @@ class LLM: ...@@ -545,11 +545,13 @@ class LLM:
total=num_requests, total=num_requests,
desc="Processed prompts", desc="Processed prompts",
dynamic_ncols=True, dynamic_ncols=True,
postfix=f"Generation Speed: {0:.2f} toks/s", postfix=(f"est. speed input: {0:.2f} toks/s, "
f"output: {0:.2f} toks/s"),
) )
# Run the engine. # Run the engine.
outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = []
total_toks = 0 total_in_toks = 0
total_out_toks = 0
while self.llm_engine.has_unfinished_requests(): while self.llm_engine.has_unfinished_requests():
step_outputs = self.llm_engine.step() step_outputs = self.llm_engine.step()
for output in step_outputs: for output in step_outputs:
...@@ -558,10 +560,15 @@ class LLM: ...@@ -558,10 +560,15 @@ class LLM:
if use_tqdm: if use_tqdm:
if isinstance(output, RequestOutput): if isinstance(output, RequestOutput):
# Calculate tokens only for RequestOutput # Calculate tokens only for RequestOutput
total_toks += sum( total_in_toks += len(output.prompt_token_ids)
in_spd = total_in_toks / pbar.format_dict["elapsed"]
total_out_toks += sum(
len(stp.token_ids) for stp in output.outputs) len(stp.token_ids) for stp in output.outputs)
spd = total_toks / pbar.format_dict["elapsed"] out_spd = total_out_toks / pbar.format_dict[
pbar.postfix = f"Generation Speed: {spd:.2f} toks/s" "elapsed"]
pbar.postfix = (
f"est. speed input: {in_spd:.2f} toks/s, "
f"output: {out_spd:.2f} toks/s")
pbar.update(1) pbar.update(1)
if use_tqdm: if use_tqdm:
pbar.close() pbar.close()
......
...@@ -15,7 +15,6 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse ...@@ -15,7 +15,6 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
from prometheus_client import make_asgi_app from prometheus_client import make_asgi_app
from starlette.routing import Mount from starlette.routing import Mount
import vllm
import vllm.envs as envs import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.engine.async_llm_engine import AsyncLLMEngine
...@@ -29,6 +28,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion ...@@ -29,6 +28,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds TIMEOUT_KEEP_ALIVE = 5 # seconds
...@@ -93,7 +93,7 @@ async def show_available_models(): ...@@ -93,7 +93,7 @@ async def show_available_models():
@app.get("/version") @app.get("/version")
async def show_version(): async def show_version():
ver = {"version": vllm.__version__} ver = {"version": VLLM_VERSION}
return JSONResponse(content=ver) return JSONResponse(content=ver)
...@@ -174,7 +174,7 @@ if __name__ == "__main__": ...@@ -174,7 +174,7 @@ if __name__ == "__main__":
raise ValueError(f"Invalid middleware {middleware}. " raise ValueError(f"Invalid middleware {middleware}. "
f"Must be a function or a class.") f"Must be a function or a class.")
logger.info("vLLM API server version %s", vllm.__version__) logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args) logger.info("args: %s", args)
if args.served_model_name is not None: if args.served_model_name is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment