Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
...@@ -27,29 +27,27 @@ class ipex_ops: ...@@ -27,29 +27,27 @@ class ipex_ops:
@staticmethod @staticmethod
def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) ipex.llm.functional.silu_and_mul(x, out)
ipex.llm.functional.silu_mul(x1, x2, out)
@staticmethod @staticmethod
def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) ipex.llm.functional.gelu_and_mul(x, out)
ipex.llm.functional.gelu_mul(x1, x2, out, "none")
@staticmethod @staticmethod
def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_tanh_and_mul(out: torch.Tensor, x: torch.Tensor) -> None:
x1, x2 = ipex_ops._reshape_activation_tensor(x) ipex.llm.functional.gelu_and_mul(x, out)
ipex.llm.functional.gelu_mul(x1, x2, out, "tanh")
@staticmethod @staticmethod
def gelu_fast(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_fast(x: torch.Tensor) -> torch.Tensor:
out.copy_(torch.nn.functional.gelu(x)) return torch.nn.functional.gelu(x)
@staticmethod @staticmethod
def gelu_new(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_new(x: torch.Tensor) -> torch.Tensor:
out.copy_(torch.nn.functional.gelu(x)) return torch.nn.functional.gelu(x)
# TODO add implementation of gelu_quick here @staticmethod
# def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None: def gelu_quick(out: torch.Tensor, x: torch.Tensor) -> None:
ipex.llm.functional.gelu_quick(x, out)
@staticmethod @staticmethod
def paged_attention_v1( def paged_attention_v1(
...@@ -160,29 +158,10 @@ class ipex_ops: ...@@ -160,29 +158,10 @@ class ipex_ops:
cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim] cos_sin_cache: torch.Tensor, # [cos_sin_dim, rot_dim]
is_neox: bool, is_neox: bool,
) -> None: ) -> None:
if positions.dim() == 1: rot_dim = cos_sin_cache.size(1)
positions = positions.unsqueeze(0) ipex.llm.functional.rotary_embedding_batched(positions, query, key,
query = query.unsqueeze(0) head_size, cos_sin_cache,
key = key.unsqueeze(0) is_neox, rot_dim)
rotary_dim = cos_sin_cache.size(1)
query = query.view(*query.shape[:-1], -1, head_size)
key = key.view(*key.shape[:-1], -1, head_size)
query_rot = query[..., :rotary_dim]
key_rot = key[..., :rotary_dim]
cos_sin = cos_sin_cache[positions.long()]
cos, sin = cos_sin.chunk(2, dim=-1)
if is_neox:
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod @staticmethod
def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor, def batched_rotary_embedding(positions: torch.Tensor, query: torch.Tensor,
...@@ -190,37 +169,15 @@ class ipex_ops: ...@@ -190,37 +169,15 @@ class ipex_ops:
cos_sin_cache: torch.Tensor, is_neox: bool, cos_sin_cache: torch.Tensor, is_neox: bool,
rot_dim: int, rot_dim: int,
cos_sin_cache_offsets: torch.Tensor) -> None: cos_sin_cache_offsets: torch.Tensor) -> None:
if positions.dim() == 1: ipex.llm.functional.rotary_embedding_batched(positions, query, key,
positions = positions.unsqueeze(0) head_size, cos_sin_cache,
query = query.unsqueeze(0) is_neox, rot_dim,
key = key.unsqueeze(0) cos_sin_cache_offsets)
cos_sin_cache_offsets = cos_sin_cache_offsets.view_as(positions)
rotary_dim = cos_sin_cache.size(1)
query = query.view(*query.shape[:-1], -1, head_size)
key = key.view(*key.shape[:-1], -1, head_size)
query_rot = query[..., :rotary_dim]
key_rot = key[..., :rotary_dim]
cos_sin = cos_sin_cache[torch.add(positions,
cos_sin_cache_offsets).long()]
cos, sin = cos_sin.chunk(2, dim=-1)
if is_neox:
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
ipex.llm.functional.rotary_embedding(query_rot, key_rot, sin, cos,
rotary_dim, is_neox, positions)
@staticmethod @staticmethod
def rms_norm(out: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, def rms_norm(input: torch.Tensor, weight: torch.Tensor,
epsilon: float) -> None: epsilon: float) -> torch.Tensor:
tmp = ipex.llm.functional.rms_norm(input, weight, epsilon) return ipex.llm.functional.rms_norm(input, weight, epsilon)
out.copy_(tmp)
@staticmethod @staticmethod
def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
...@@ -246,11 +203,14 @@ class ipex_ops: ...@@ -246,11 +203,14 @@ class ipex_ops:
return_softmax: bool, return_softmax: bool,
gen_: torch.Generator, gen_: torch.Generator,
) -> None: ) -> None:
ipex.llm.functional.varlen_attention(query, key, value, out, seqlen_q, ipex.llm.functional.varlen_attention(query.contiguous(),
seqlen_k, max_seqlen_q, key.contiguous(),
max_seqlen_k, pdropout, value.contiguous(), out,
softmax_scale, zero_tensors, seqlen_q.int(), seqlen_k.int(),
is_causal, return_softmax, gen_) max_seqlen_q, max_seqlen_k,
pdropout, softmax_scale,
zero_tensors, is_causal,
return_softmax, gen_)
@staticmethod @staticmethod
def reshape_and_cache( def reshape_and_cache(
......
...@@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: ...@@ -42,7 +42,7 @@ def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]:
def get_adapter(adapter_id: int, def get_adapter(adapter_id: int,
registered_adapters: Dict[int, Any]) -> Optional[Any]: registered_adapters: Dict[int, Any]) -> Optional[Any]:
return registered_adapters.get(adapter_id, None) return registered_adapters.get(adapter_id)
## worker functions ## worker functions
......
...@@ -79,7 +79,7 @@ class VideoAsset: ...@@ -79,7 +79,7 @@ class VideoAsset:
return ret return ret
@property @property
def np_ndarrays(self) -> List[npt.NDArray]: def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.name) video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames) ret = video_to_ndarrays(video_path, self.num_frames)
return ret return ret
...@@ -156,18 +156,27 @@ class AttentionState(ABC, Generic[T]): ...@@ -156,18 +156,27 @@ class AttentionState(ABC, Generic[T]):
... ...
@abstractmethod @abstractmethod
def graph_capture_get_metadata_for_batch(self, batch_size: int) -> T: def graph_capture_get_metadata_for_batch(
self,
batch_size: int,
is_encoder_decoder_model: bool = False) -> T:
"""Get attention metadata for CUDA graph capture of batch_size.""" """Get attention metadata for CUDA graph capture of batch_size."""
... ...
@abstractmethod @abstractmethod
def get_graph_input_buffers(self, attn_metadata: T) -> Dict[str, Any]: def get_graph_input_buffers(
self,
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
"""Get attention-specific input buffers for CUDA graph capture.""" """Get attention-specific input buffers for CUDA graph capture."""
... ...
@abstractmethod @abstractmethod
def prepare_graph_input_buffers(self, input_buffers: Dict[str, Any], def prepare_graph_input_buffers(
attn_metadata: T) -> None: self,
input_buffers: Dict[str, Any],
attn_metadata: T,
is_encoder_decoder_model: bool = False) -> None:
"""In-place modify input buffers dict for CUDA graph replay.""" """In-place modify input buffers dict for CUDA graph replay."""
... ...
......
...@@ -19,8 +19,13 @@ if TYPE_CHECKING: ...@@ -19,8 +19,13 @@ if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder, from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata) ModelInputForGPUWithSamplingMetadata)
from vllm_flash_attn import flash_attn_varlen_func as _flash_attn_varlen_func # yapf: disable
from vllm_flash_attn import flash_attn_with_kvcache as _flash_attn_with_kvcache from vllm.vllm_flash_attn import (
flash_attn_varlen_func as _flash_attn_varlen_func)
from vllm.vllm_flash_attn import (
flash_attn_with_kvcache as _flash_attn_with_kvcache)
# yapf: enable
@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) @torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[])
......
...@@ -172,7 +172,8 @@ class FlashInferState(AttentionState): ...@@ -172,7 +172,8 @@ class FlashInferState(AttentionState):
state._prefill_wrapper = self._get_prefill_wrapper() state._prefill_wrapper = self._get_prefill_wrapper()
return state return state
def graph_capture_get_metadata_for_batch(self, batch_size: int): def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing assert self._is_graph_capturing
_indptr_buffer = self._graph_indptr_buffer[:batch_size + 1] _indptr_buffer = self._graph_indptr_buffer[:batch_size + 1]
_last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size] _last_page_len_buffer = self._graph_last_page_len_buffer[:batch_size]
...@@ -232,12 +233,17 @@ class FlashInferState(AttentionState): ...@@ -232,12 +233,17 @@ class FlashInferState(AttentionState):
attn_metadata.begin_forward() attn_metadata.begin_forward()
return attn_metadata return attn_metadata
def get_graph_input_buffers(self, attn_metadata): def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
return { return {
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
} }
def prepare_graph_input_buffers(self, input_buffers, attn_metadata): def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
return return
def begin_forward(self, model_input): def begin_forward(self, model_input):
......
...@@ -49,14 +49,18 @@ class IpexAttnBackend(AttentionBackend): ...@@ -49,14 +49,18 @@ class IpexAttnBackend(AttentionBackend):
dst_kv_cache: torch.Tensor, dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor, src_to_dst: torch.Tensor,
) -> None: ) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst) from vllm._ipex_ops import ipex_ops as ops
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod @staticmethod
def copy_blocks( def copy_blocks(
kv_caches: List[torch.Tensor], kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor, src_to_dists: torch.Tensor,
) -> None: ) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists) from vllm._ipex_ops import ipex_ops as ops
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)
@dataclass @dataclass
......
"""Attention layer ROCm GPUs.""" """Attention layer ROCm GPUs."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType) AttentionMetadata, AttentionType)
from vllm.attention.backends.utils import (CommonAttentionState, from vllm.attention.backends.utils import (CommonAttentionState,
...@@ -12,9 +13,16 @@ from vllm.attention.backends.utils import (CommonAttentionState, ...@@ -12,9 +13,16 @@ from vllm.attention.backends.utils import (CommonAttentionState,
from vllm.attention.ops.paged_attn import (PagedAttention, from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata) PagedAttentionMetadata)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__) logger = init_logger(__name__)
_PARTITION_SIZE_ROCM = 512
_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
class ROCmFlashAttentionBackend(AttentionBackend): class ROCmFlashAttentionBackend(AttentionBackend):
...@@ -175,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -175,6 +183,59 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
) )
return self._cached_decode_metadata return self._cached_decode_metadata
def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int, num_seqs: int, num_queries: int):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class ROCmFlashAttentionMetadataBuilder( class ROCmFlashAttentionMetadataBuilder(
CommonMetadataBuilder[ROCmFlashAttentionMetadata]): CommonMetadataBuilder[ROCmFlashAttentionMetadata]):
...@@ -297,7 +358,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -297,7 +358,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn # if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either # either
if torch.cuda.get_device_capability()[0] != 9: if not current_platform.has_device_capability(90):
self.use_naive_attn = True self.use_naive_attn = True
else: else:
try: try:
...@@ -507,20 +568,64 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -507,20 +568,64 @@ class ROCmFlashAttentionImpl(AttentionImpl):
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
# Decoding run. # Decoding run.
output[num_prefill_tokens:] = PagedAttention.forward_decode( # Whether to use rocm custom paged attention or not
decode_query, num_seqs, num_heads, head_size = decode_query.shape
key_cache, block_size = value_cache.shape[3]
value_cache, gqa_ratio = num_heads // self.num_kv_heads
decode_meta.block_tables, use_custom = _use_rocm_custom_paged_attention(
decode_meta.seq_lens_tensor, decode_query.dtype, head_size, block_size, gqa_ratio,
decode_meta.max_decode_seq_len, decode_meta.max_decode_seq_len)
self.kv_cache_dtype, if use_custom:
self.num_kv_heads, max_seq_len = decode_meta.max_decode_seq_len
self.scale, max_num_partitions = (
self.alibi_slopes, (max_seq_len + _PARTITION_SIZE_ROCM - 1) //
k_scale, _PARTITION_SIZE_ROCM)
v_scale, assert _PARTITION_SIZE_ROCM % block_size == 0
) tmp_output = torch.empty(
size=(num_seqs, num_heads, max_num_partitions, head_size),
dtype=output.dtype,
device=output.device,
)
exp_sums = torch.empty(
size=(num_seqs, num_heads, max_num_partitions),
dtype=torch.float32,
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
ops.paged_attention_rocm(
output[num_prefill_tokens:],
exp_sums,
max_logits,
tmp_output,
decode_query,
key_cache,
value_cache,
self.num_kv_heads,
self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
block_size,
max_seq_len,
self.alibi_slopes,
self.kv_cache_dtype,
k_scale,
v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables,
decode_meta.seq_lens_tensor,
decode_meta.max_decode_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
k_scale,
v_scale,
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(num_tokens, hidden_size) return output.view(num_tokens, hidden_size)
...@@ -558,4 +663,14 @@ def _sdpa_attention( ...@@ -558,4 +663,14 @@ def _sdpa_attention(
output[start:end, :, :] = sub_out output[start:end, :, :] = sub_out
start = end start = end
return output return output
\ No newline at end of file
def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
block_size: int, gqa_ratio: int,
max_seq_len: int) -> bool:
# rocm custom page attention not support on navi (gfx1*)
return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
...@@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]): ...@@ -33,10 +33,8 @@ def is_block_tables_empty(block_tables: Union[None, Dict]):
""" """
if block_tables is None: if block_tables is None:
return True return True
if isinstance(block_tables, dict) and all( return (isinstance(block_tables, dict)
value is None for value in block_tables.values()): and all(value is None for value in block_tables.values()))
return True
return False
def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int, def compute_slot_mapping_start_idx(is_prompt: bool, query_len: int,
...@@ -304,7 +302,8 @@ class CommonAttentionState(AttentionState): ...@@ -304,7 +302,8 @@ class CommonAttentionState(AttentionState):
assert self._is_graph_capturing assert self._is_graph_capturing
return self.__class__(self.runner) return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(self, batch_size: int): def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata( attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0, num_prefills=0,
...@@ -322,21 +321,121 @@ class CommonAttentionState(AttentionState): ...@@ -322,21 +321,121 @@ class CommonAttentionState(AttentionState):
block_tables=self._graph_block_tables[:batch_size], block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True, use_cuda_graph=True,
) )
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._update_captured_metadata_for_enc_dec_model(
batch_size=batch_size, attn_metadata=attn_metadata)
return attn_metadata return attn_metadata
def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]: def get_graph_input_buffers(
return { self,
attn_metadata,
is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping, "slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor, "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables, "block_tables": attn_metadata.decode_metadata.block_tables,
} }
if is_encoder_decoder_model:
def prepare_graph_input_buffers(self, input_buffers, # The encoder decoder model works only with XFormers backend.
attn_metadata) -> None: # Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._add_additonal_input_buffers_for_enc_dec_model(
attn_metadata=attn_metadata, input_buffers=input_buffers)
return input_buffers
def prepare_graph_input_buffers(
self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False) -> None:
input_buffers["seq_lens_tensor"].copy_( input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True) attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_( input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True) attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
# The encoder decoder model works only with XFormers backend.
# Assert the same.
assert self.runner.attn_backend.get_name() == "xformers", \
f"Expected attn_backend name to be 'xformers', but "\
f" got '{self.runner.attn_backend.get_name()}'"
self._prepare_input_buffers_for_enc_dec_model(
attn_metadata, input_buffers)
def begin_forward(self, model_input) -> None: def begin_forward(self, model_input) -> None:
return return
def _update_captured_metadata_for_enc_dec_model(self, batch_size: int,
attn_metadata):
"""
Updates the attention metadata parameters for CUDA graph capture in an
encoder-decoder model.
This method modifies attention-related tensors and metadata required
for CUDA graph capture in encoder-decoder models. Specifically, it
updates the cross-attention and encoder sequence tensors in the
AttentionMetadata object.
"""
# During decode phase the cross_slot_mapping will be empty. Hence set
# an empty tensor for CUDA Graph capture.
attn_metadata.cross_slot_mapping = torch.tensor(
[], dtype=torch.int).cuda()
attn_metadata.cross_block_tables = torch.full(
(batch_size, self.runner.get_max_block_per_batch()),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens = torch.full((batch_size, ),
1,
dtype=torch.int).cuda()
attn_metadata.encoder_seq_lens_tensor = torch.full(
(batch_size, ), 1, dtype=torch.int).cuda()
attn_metadata.max_encoder_seq_len = self.runner.max_seq_len_to_capture
def _add_additonal_input_buffers_for_enc_dec_model(
self, attn_metadata, input_buffers: Dict[str, Any]):
"""
Saves additional input buffers specific to the encoder-decoder model
from the attention metadata.
This method extracts and stores encoder-decoder related input buffers
from the `attn_metadata` into the `input_buffers` dictionary. The
buffers include encoder sequence lengths, cross-slot mappings, and
cross-block tables, which are essential for the encoder-decoder model
during CUDA graph replay.
"""
input_buffers["encoder_seq_lens_tensor"] = (
attn_metadata.decode_metadata.encoder_seq_lens_tensor)
input_buffers["cross_slot_mapping"] = (
attn_metadata.decode_metadata.cross_slot_mapping)
input_buffers["cross_block_tables"] = (
attn_metadata.decode_metadata.cross_block_tables)
def _prepare_input_buffers_for_enc_dec_model(self, attn_metadata,
input_buffers: Dict[str,
Any]):
"""
Populates input buffers with data from the encoder-decoder model's
attention metadata.
This method fills the input buffers with encoder-decoder specific
tensors. It copies data from the `attn_metadata` and keyword arguments
(`kwargs`) into corresponding buffers in the `input_buffers` dictionary.
The copied data includes attention-related metadata as well as input
IDs and positional information for the encoder.
"""
input_buffers["encoder_seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.encoder_seq_lens_tensor,
non_blocking=True)
input_buffers["cross_slot_mapping"].copy_(
attn_metadata.decode_metadata.cross_slot_mapping,
non_blocking=True)
input_buffers["cross_block_tables"].copy_(
attn_metadata.decode_metadata.cross_block_tables,
non_blocking=True)
...@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip ...@@ -8,8 +8,7 @@ from vllm.utils import is_cpu, is_hip
from .utils import (dense_to_crow_col, get_head_sliding_step, from .utils import (dense_to_crow_col, get_head_sliding_step,
get_sparse_attn_mask) get_sparse_attn_mask)
IS_COMPUTE_8_OR_ABOVE = (torch.cuda.is_available() IS_COMPUTE_8_OR_ABOVE = current_platform.has_device_capability(80)
and current_platform.get_device_capability()[0] >= 8)
if IS_COMPUTE_8_OR_ABOVE: if IS_COMPUTE_8_OR_ABOVE:
from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd from .blocksparse_attention_kernel import blocksparse_flash_attn_varlen_fwd
...@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module): ...@@ -36,7 +35,7 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
use_spda = is_hip() or is_cpu() or not \ use_spda = is_hip() or is_cpu() or not \
IS_COMPUTE_8_OR_ABOVE IS_COMPUTE_8_OR_ABOVE
device = device or (torch.cuda.current_device() device = device or (torch.cuda.current_device()
if torch.cuda.is_available() else "cpu") if current_platform.is_cuda_alike() else "cpu")
device = torch.device(device) device = torch.device(device)
# NOTE: vllm CPU backend support BF16 instead of FP16. # NOTE: vllm CPU backend support BF16 instead of FP16.
dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE dtype = dtype or (torch.bfloat16 if IS_COMPUTE_8_OR_ABOVE
......
...@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0": ...@@ -709,8 +709,7 @@ if triton.__version__ >= "2.1.0":
alibi_slopes=None, alibi_slopes=None,
sliding_window=None): sliding_window=None):
cap = current_platform.get_device_capability() BLOCK = 32 if current_platform.has_device_capability(80) else 32
BLOCK = 32 if cap[0] >= 8 else 32
NUM_WARPS = 8 NUM_WARPS = 8
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
......
...@@ -203,7 +203,7 @@ def which_attn_to_use( ...@@ -203,7 +203,7 @@ def which_attn_to_use(
selected_backend = (_Backend.ROCM_FLASH if selected_backend selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend) == _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
if current_platform.get_device_capability()[0] != 9: if not current_platform.has_device_capability(90):
# not Instinct series GPUs. # not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.") logger.info("flash_attn is not supported on NAVI GPUs.")
else: else:
...@@ -212,7 +212,7 @@ def which_attn_to_use( ...@@ -212,7 +212,7 @@ def which_attn_to_use(
# FlashAttn in NVIDIA GPUs. # FlashAttn in NVIDIA GPUs.
if selected_backend == _Backend.FLASH_ATTN: if selected_backend == _Backend.FLASH_ATTN:
if current_platform.get_device_capability()[0] < 8: if not current_platform.has_device_capability(80):
# Volta and Turing NVIDIA GPUs. # Volta and Turing NVIDIA GPUs.
logger.info( logger.info(
"Cannot use FlashAttention-2 backend for Volta and Turing " "Cannot use FlashAttention-2 backend for Volta and Turing "
...@@ -244,8 +244,7 @@ def which_attn_to_use( ...@@ -244,8 +244,7 @@ def which_attn_to_use(
# FlashAttn is valid for the model, checking if the package is installed. # FlashAttn is valid for the model, checking if the package is installed.
if selected_backend == _Backend.FLASH_ATTN: if selected_backend == _Backend.FLASH_ATTN:
try: try:
import vllm_flash_attn # noqa: F401 import vllm.vllm_flash_attn # noqa: F401
from vllm.attention.backends.flash_attn import ( # noqa: F401 from vllm.attention.backends.flash_attn import ( # noqa: F401
FlashAttentionBackend) FlashAttentionBackend)
...@@ -258,8 +257,9 @@ def which_attn_to_use( ...@@ -258,8 +257,9 @@ def which_attn_to_use(
except ImportError: except ImportError:
logger.info( logger.info(
"Cannot use FlashAttention-2 backend because the " "Cannot use FlashAttention-2 backend because the "
"vllm_flash_attn package is not found. " "vllm.vllm_flash_attn package is not found. "
"`pip install vllm-flash-attn` for better performance.") "Make sure that vllm_flash_attn was built and installed "
"(on by default).")
selected_backend = _Backend.XFORMERS selected_backend = _Backend.XFORMERS
return selected_backend return selected_backend
......
...@@ -93,6 +93,7 @@ def run_vllm( ...@@ -93,6 +93,7 @@ def run_vllm(
download_dir: Optional[str] = None, download_dir: Optional[str] = None,
load_format: str = EngineArgs.load_format, load_format: str = EngineArgs.load_format,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
use_new_beam_search_impl: bool = False,
) -> float: ) -> float:
from vllm import LLM, SamplingParams from vllm import LLM, SamplingParams
llm = LLM( llm = LLM(
...@@ -169,10 +170,24 @@ def run_vllm( ...@@ -169,10 +170,24 @@ def run_vllm(
# print("Warming up...") # print("Warming up...")
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): # for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion() # run_to_completion()
start = time.perf_counter() if not use_new_beam_search_impl:
llm.generate(prompts, sampling_params, use_tqdm=True) start = time.perf_counter()
end = time.perf_counter() llm.generate(prompts, sampling_params, use_tqdm=True)
end = time.perf_counter()
else:
assert use_beam_search
prompts = [prompt for prompt, _, _ in requests]
# output_len should be the same for all requests.
output_len = requests[0][2]
for prompt, input_len, _output_len in requests:
assert _output_len == output_len
start = time.perf_counter()
llm.beam_search(prompts,
beam_width=n,
max_tokens=output_len,
ignore_eos=True)
end = time.perf_counter()
return end - start return end - start
...@@ -229,7 +244,6 @@ async def run_vllm_async( ...@@ -229,7 +244,6 @@ async def run_vllm_async(
use_v2_block_manager=use_v2_block_manager, use_v2_block_manager=use_v2_block_manager,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
worker_use_ray=False, worker_use_ray=False,
engine_use_ray=False,
disable_log_requests=True, disable_log_requests=True,
) )
...@@ -378,7 +392,7 @@ def main(args: argparse.Namespace): ...@@ -378,7 +392,7 @@ def main(args: argparse.Namespace):
run_args.append(args.disable_frontend_multiprocessing) run_args.append(args.disable_frontend_multiprocessing)
elapsed_time = uvloop.run(run_vllm_async(*run_args)) elapsed_time = uvloop.run(run_vllm_async(*run_args))
else: else:
elapsed_time = run_vllm(*run_args) elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl)
elif args.backend == "hf": elif args.backend == "hf":
assert args.tensor_parallel_size == 1 assert args.tensor_parallel_size == 1
elapsed_time = run_hf(requests, args.model, tokenizer, args.n, elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
...@@ -450,6 +464,7 @@ if __name__ == "__main__": ...@@ -450,6 +464,7 @@ if __name__ == "__main__":
type=int, type=int,
default=1, default=1,
help='Number of iterations to run for warmup.') help='Number of iterations to run for warmup.')
parser.add_argument("--use-new-beam-search-impl", action="store_true")
parser.add_argument("--num-prompts", parser.add_argument("--num-prompts",
type=int, type=int,
default=1000, default=1000,
......
import operator
import torch
import torch.fx as fx
def fix_functionalization(graph: fx.Graph):
"""
Rewrite the graph module to replace the pattern involving
torch._higher_order_ops.auto_functionalize.auto_functionalized
with a direct call to the inplace custom op.
# TODO: check if PyTorch nightly has fixed this issue
"""
# debug code, if we want to see the graph before the transformation
# with open("before.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
nodes_to_remove = []
for node in graph.nodes:
# Identify the auto_functionalized node
if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa
if node.args[0] == torch.ops._C.rotary_embedding.default:
# manual replace for rotary_embedding
# Now, collect the arguments
kwargs = node.kwargs
query = kwargs['query']
mm_node = query.args[0].args[0]
# Create a new call to torch.ops._C.rotary_embedding.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(torch.ops._C.rotary_embedding.default,
kwargs=kwargs)
# Remove the auto_functionalized node
# Since the node may have outputs, we need to handle its users
# Replace uses of the outputs (getitem nodes) with mm_node
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
for getitem_user in list(user.users):
if (getitem_user.op == 'call_function'
and getitem_user.target
== torch.ops.aten.slice_scatter.default):
# Replace the uses of slice_scatter node
# with mm_node
getitem_user.replace_all_uses_with(mm_node)
nodes_to_remove.append(getitem_user)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.fused_add_rms_norm.default:
# manual replace for fused_add_rms_norm
# this is the most effective optimization for llama
# failing to do this will result in many unnecessary copies
kwargs = node.kwargs
input = kwargs['input']
residual = kwargs['residual']
# Create a new call to torch.ops._C.rotary_embedding.default
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs)
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
# Remove the getitem node
if user.args[1] == 1:
replace_node = input
elif user.args[1] == 2:
replace_node = residual
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.rms_norm.default:
# manual replace for rms_norm
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
weight = kwargs['weight']
epsilon = kwargs['epsilon']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.rms_norm.default,
args=(out, input, weight, epsilon),
)
replace_node = out
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
elif node.args[0] == torch.ops._C.silu_and_mul.default:
# manual replace for silu_and_mul
kwargs = node.kwargs
input = kwargs['input']
out = kwargs['out']
# Create a new call to torch.ops._C.rotary_embedding.default
# cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa
with graph.inserting_before(node):
# just insert the call to the custom op
# NOTE: don't run dead code elimination,
# otherwise this op will be removed
graph.call_function(
torch.ops._C.silu_and_mul.default,
args=(out, input),
)
replace_node = out
for user in list(node.users):
if user.op == 'call_function' and user.target == operator.getitem: # noqa
user.replace_all_uses_with(replace_node)
nodes_to_remove.append(user)
nodes_to_remove.append(node)
# Remove the nodes all at once
for node in nodes_to_remove:
graph.erase_node(node)
# debug code, if we want to see the graph after the transformation
# with open("after.py", "w") as f:
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
def vllm_backend(graph, example_inputs):
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx
current_config['post_grad_custom_post_pass'] = fix_functionalization
return compile_fx(graph, example_inputs, config_patches=current_config)
...@@ -16,8 +16,7 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback ...@@ -16,8 +16,7 @@ from vllm.tracing import is_otel_available, otel_import_error_traceback
from vllm.transformers_utils.config import (ConfigFormat, get_config, from vllm.transformers_utils.config import (ConfigFormat, get_config,
get_hf_image_processor_config, get_hf_image_processor_config,
get_hf_text_config) get_hf_text_config)
from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory,
cuda_device_count_stateless, get_cpu_memory, is_cpu,
is_hip, is_neuron, is_openvino, is_xpu, is_hip, is_neuron, is_openvino, is_xpu,
print_warning_once) print_warning_once)
...@@ -52,6 +51,7 @@ _PP_SUPPORTED_MODELS = [ ...@@ -52,6 +51,7 @@ _PP_SUPPORTED_MODELS = [
"Qwen2ForCausalLM", "Qwen2ForCausalLM",
"Qwen2MoeForCausalLM", "Qwen2MoeForCausalLM",
"QWenLMHeadModel", "QWenLMHeadModel",
"Qwen2VLForConditionalGeneration",
] ]
...@@ -96,15 +96,15 @@ class ModelConfig: ...@@ -96,15 +96,15 @@ class ModelConfig:
enforce_eager: Whether to enforce eager execution. If True, we will enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode. disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid. If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False - If None, the user did not specify, so default to False.
except for encoder/decoder models, which currently require
eager mode.
max_context_len_to_capture: Maximum context len covered by CUDA graphs. max_context_len_to_capture: Maximum context len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). to eager mode (DEPRECATED. Use max_seq_len_to_capture instead).
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back When a sequence has context length larger than this, we fall back
to eager mode to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
disable_sliding_window: Whether to disable sliding window. If True, disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model. we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is If the model does not support sliding window, this argument is
...@@ -123,6 +123,8 @@ class ModelConfig: ...@@ -123,6 +123,8 @@ class ModelConfig:
can not be gathered from the vllm arguments. can not be gathered from the vllm arguments.
config_format: The config format which shall be loaded. config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'. Defaults to 'auto' which defaults to 'hf'.
mm_processor_kwargs: Arguments to be forwarded to the model's processor
for multi-modal data, e.g., image processor.
""" """
def __init__(self, def __init__(self,
...@@ -151,7 +153,8 @@ class ModelConfig: ...@@ -151,7 +153,8 @@ class ModelConfig:
limit_mm_per_prompt: Optional[Mapping[str, int]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
use_async_output_proc: bool = True, use_async_output_proc: bool = True,
override_neuron_config: Optional[Dict[str, Any]] = None, override_neuron_config: Optional[Dict[str, Any]] = None,
config_format: ConfigFormat = ConfigFormat.AUTO) -> None: config_format: ConfigFormat = ConfigFormat.AUTO,
mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
self.model = model self.model = model
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_mode = tokenizer_mode self.tokenizer_mode = tokenizer_mode
...@@ -185,33 +188,10 @@ class ModelConfig: ...@@ -185,33 +188,10 @@ class ModelConfig:
self.model, revision) self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
self.use_async_output_proc = use_async_output_proc self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs
# Choose a default enforce_eager value if the user did not specify # Set enforce_eager to False if the value is unset.
# a value (enforce_eager is None) if self.enforce_eager is None:
if getattr(self.hf_config, 'is_encoder_decoder', False):
if self.enforce_eager is None:
# *Only for encoder/decoder models* and
# *only if enforce_eager is unset*, override
# to enforce_eager=True
#
# Add a logger message since it is *somewhat* non-intuitive that
# enforce_eager is True when the user has not specified its
# value.
logger.info("Forcing enforce_eager == True because "
"enforce_eager setting was unspecified and "
"CUDAGraph is not supported with encoder/ "
"decoder models.")
self.enforce_eager = True
if not self.enforce_eager:
# Eager mode explicitly disabled by user for an encoder/
# decoder model; however CUDAGRAPH + encoder/decoder is
# not currently supported
raise ValueError(STR_NOT_IMPL_ENC_DEC_CUDAGRAPH)
elif self.enforce_eager is None:
# *Only for decoder-only models*, enforce_eager
# defaults to False if unset. This is intuitive
# so no logging message needed.
self.enforce_eager = False self.enforce_eager = False
if (not self.disable_sliding_window if (not self.disable_sliding_window
...@@ -242,6 +222,7 @@ class ModelConfig: ...@@ -242,6 +222,7 @@ class ModelConfig:
self._verify_embedding_mode() self._verify_embedding_mode()
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
self._verify_bnb_config()
def _init_multimodal_config( def _init_multimodal_config(
self, limit_mm_per_prompt: Optional[Mapping[str, int]] self, limit_mm_per_prompt: Optional[Mapping[str, int]]
...@@ -280,7 +261,13 @@ class ModelConfig: ...@@ -280,7 +261,13 @@ class ModelConfig:
def _verify_quantization(self) -> None: def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS] supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["awq", "gptq"] # "fp8" rocm_supported_quantization = [
"awq", "gptq", "compressed-tensors"
]
# rocm_supported_quantization = [
# "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
# "fbgemm_fp8"
# ]
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors", "awq_marlin", "fbgemm_fp8", "compressed_tensors",
...@@ -354,6 +341,28 @@ class ModelConfig: ...@@ -354,6 +341,28 @@ class ModelConfig:
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len) self.max_model_len)
def _verify_bnb_config(self) -> None:
"""
The current version of bitsandbytes (0.44.0) with 8-bit models does not
yet support CUDA graph.
"""
is_bitsandbytes = self.quantization == "bitsandbytes"
has_quantization_config = (getattr(self.hf_config,
"quantization_config", None)
is not None)
is_8bit = (self.hf_config.quantization_config.get(
"load_in_8bit", False) if has_quantization_config else False)
if all([
is_bitsandbytes,
has_quantization_config,
is_8bit,
not self.enforce_eager,
]):
logger.warning(
"CUDA graph is not supported on BitAndBytes 8bit yet, "
"fallback to the eager mode.")
self.enforce_eager = True
def verify_async_output_proc(self, parallel_config, speculative_config, def verify_async_output_proc(self, parallel_config, speculative_config,
device_config) -> None: device_config) -> None:
if not self.use_async_output_proc: if not self.use_async_output_proc:
...@@ -379,7 +388,7 @@ class ModelConfig: ...@@ -379,7 +388,7 @@ class ModelConfig:
self.use_async_output_proc = False self.use_async_output_proc = False
return return
if self.enforce_eager: if device_config.device_type == "cuda" and self.enforce_eager:
logger.warning( logger.warning(
"To see benefits of async output processing, enable CUDA " "To see benefits of async output processing, enable CUDA "
"graph. Since, enforce-eager is enabled, async output " "graph. Since, enforce-eager is enabled, async output "
...@@ -418,19 +427,6 @@ class ModelConfig: ...@@ -418,19 +427,6 @@ class ModelConfig:
"Pipeline parallelism is only supported for the following " "Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.") f" architectures: {_PP_SUPPORTED_MODELS}.")
if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
raise ValueError(
"BitAndBytes quantization with TP or PP is not supported yet.")
# Remove the constraint after the bitsandbytes issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
if self.quantization == "bitsandbytes" and self.enforce_eager is False:
logger.warning("CUDA graph is not supported on BitAndBytes yet, "
"fallback to the eager mode.")
self.enforce_eager = True
if pipeline_parallel_size > 1 and self.use_async_output_proc: if pipeline_parallel_size > 1 and self.use_async_output_proc:
logger.warning("Async output processor is not supported with " logger.warning("Async output processor is not supported with "
"pipeline parallelism currently. Disabling it.") "pipeline parallelism currently. Disabling it.")
...@@ -583,7 +579,9 @@ class ModelConfig: ...@@ -583,7 +579,9 @@ class ModelConfig:
@property @property
def is_encoder_decoder_model(self) -> bool: def is_encoder_decoder_model(self) -> bool:
"""Extract the HF encoder/decoder model flag.""" """Extract the HF encoder/decoder model flag."""
return getattr(self.hf_config, "is_encoder_decoder", False) return getattr(self.hf_config, "is_encoder_decoder", False) or (
(hasattr(self.hf_config, "text_config") and getattr(
self.hf_config.text_config, "is_encoder_decoder", False)))
@property @property
def is_embedding_model(self) -> bool: def is_embedding_model(self) -> bool:
...@@ -968,7 +966,7 @@ class SchedulerConfig: ...@@ -968,7 +966,7 @@ class SchedulerConfig:
workers instead of an entire data. It should be enabled only workers instead of an entire data. It should be enabled only
when SPMD worker architecture is enabled. I.e., when SPMD worker architecture is enabled. I.e.,
VLLM_USE_RAY_SPMD_WORKER=1 VLLM_USE_RAY_SPMD_WORKER=1
policy: The scheduling policy to use. "fcfs" (default) or "priority".
""" """
def __init__(self, def __init__(self,
...@@ -983,7 +981,9 @@ class SchedulerConfig: ...@@ -983,7 +981,9 @@ class SchedulerConfig:
is_multimodal_model: bool = False, is_multimodal_model: bool = False,
preemption_mode: Optional[str] = None, preemption_mode: Optional[str] = None,
num_scheduler_steps: int = 1, num_scheduler_steps: int = 1,
send_delta_data: bool = False) -> None: multi_step_stream_outputs: bool = False,
send_delta_data: bool = False,
policy: str = "fcfs") -> None:
if max_num_batched_tokens is None: if max_num_batched_tokens is None:
if enable_chunked_prefill: if enable_chunked_prefill:
# It is the values that have the best balance between ITL # It is the values that have the best balance between ITL
...@@ -1023,7 +1023,9 @@ class SchedulerConfig: ...@@ -1023,7 +1023,9 @@ class SchedulerConfig:
self.embedding_mode = embedding_mode self.embedding_mode = embedding_mode
self.preemption_mode = preemption_mode self.preemption_mode = preemption_mode
self.num_scheduler_steps = num_scheduler_steps self.num_scheduler_steps = num_scheduler_steps
self.multi_step_stream_outputs = multi_step_stream_outputs
self.send_delta_data = send_delta_data self.send_delta_data = send_delta_data
self.policy = policy
self._verify_args() self._verify_args()
def _verify_args(self) -> None: def _verify_args(self) -> None:
...@@ -1066,20 +1068,20 @@ class DeviceConfig: ...@@ -1066,20 +1068,20 @@ class DeviceConfig:
def __init__(self, device: str = "auto") -> None: def __init__(self, device: str = "auto") -> None:
if device == "auto": if device == "auto":
# Automated device type detection # Automated device type detection
if is_neuron(): if current_platform.is_cuda_alike():
self.device_type = "cuda"
elif is_neuron():
self.device_type = "neuron" self.device_type = "neuron"
elif is_openvino(): elif is_openvino():
self.device_type = "openvino" self.device_type = "openvino"
elif current_platform.is_tpu(): elif current_platform.is_tpu():
self.device_type = "tpu" self.device_type = "tpu"
elif is_cpu(): elif current_platform.is_cpu():
self.device_type = "cpu" self.device_type = "cpu"
elif is_xpu(): elif is_xpu():
self.device_type = "xpu" self.device_type = "xpu"
else: else:
# We don't call torch.cuda.is_available() here to raise RuntimeError("Failed to infer device type")
# avoid initializing CUDA before workers are forked
self.device_type = "cuda"
else: else:
# Device type is assigned explicitly # Device type is assigned explicitly
self.device_type = device self.device_type = device
......
...@@ -417,9 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator): ...@@ -417,9 +417,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def is_block_cached(self, block: Block) -> bool: def is_block_cached(self, block: Block) -> bool:
assert block.content_hash is not None assert block.content_hash is not None
if block.content_hash in self._cached_blocks: return block.content_hash in self._cached_blocks
return True
return False
def promote_to_immutable_block(self, block: Block) -> BlockId: def promote_to_immutable_block(self, block: Block) -> BlockId:
"""Once a mutable block is full, it can be promoted to an immutable """Once a mutable block is full, it can be promoted to an immutable
......
...@@ -399,9 +399,7 @@ class BlockSpaceManagerV2(BlockSpaceManager): ...@@ -399,9 +399,7 @@ class BlockSpaceManagerV2(BlockSpaceManager):
""" """
alloc_status = self._can_swap(seq_group, Device.CPU, alloc_status = self._can_swap(seq_group, Device.CPU,
SequenceStatus.RUNNING) SequenceStatus.RUNNING)
if alloc_status == AllocStatus.OK: return alloc_status == AllocStatus.OK
return True
return False
def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]: def swap_out(self, seq_group: SequenceGroup) -> List[Tuple[int, int]]:
"""Returns the block id mapping (from GPU to CPU) generated by """Returns the block id mapping (from GPU to CPU) generated by
......
...@@ -766,6 +766,79 @@ class Scheduler: ...@@ -766,6 +766,79 @@ class Scheduler:
else: else:
return prompt_limit return prompt_limit
def _get_priority(self,
seq_group: SequenceGroup) -> Tuple[Optional[int], float]:
""" Get the priority of the sequence group.
Highest preference to user-defined priority, followed by arrival time.
Args:
seq_group: The sequence group input.
Returns:
The priority of the sequence group.
"""
return seq_group.priority, seq_group.arrival_time
def _schedule_priority_preemption(
self,
budget: SchedulingBudget,
) -> int:
"""Sorts waiting and running queue. Also, force preempt requests
from the running queue if their priority is lower.
Priority-based preemption is used with the priority policy.
Args:
budget: The scheduling budget. The argument is in-place updated
when any requests are scheduled.
Returns:
A count of priority-based preemptions.
"""
waiting_queue = self.waiting
running_queue = deque(sorted(self.running, key=self._get_priority))
blocks_to_swap_out: List[Tuple[int, int]] = []
force_preemption_count = 0
if waiting_queue:
seq_group = waiting_queue.popleft()
num_new_seqs = seq_group.get_max_num_running_seqs()
num_new_tokens = self._get_num_new_tokens(seq_group,
SequenceStatus.WAITING,
False, budget)
#Only preempt if priority inversion exists
while running_queue and self._get_priority(
running_queue[-1]) > self._get_priority(seq_group):
#Only preempt if waiting sequence cannot be allocated
can_allocate = self.block_manager.can_allocate(seq_group)
if (num_new_tokens and can_allocate == AllocStatus.OK
and budget.can_schedule(num_new_tokens=num_new_tokens,
num_new_seqs=num_new_seqs)):
break
#Adjust budget to remove the victim sequence group
vseq_group = running_queue.pop()
num_running_tokens = self._get_num_new_tokens(
vseq_group, SequenceStatus.RUNNING, False, budget)
budget.subtract_num_batched_tokens(vseq_group.request_id,
num_running_tokens)
num_running_seqs = vseq_group.get_max_num_running_seqs()
budget.subtract_num_seqs(vseq_group.request_id,
num_running_seqs)
#Preempt out the victim sequence group
self._preempt(vseq_group, blocks_to_swap_out,
PreemptionMode.RECOMPUTE)
waiting_queue.appendleft(vseq_group)
force_preemption_count += 1
#Put the sequence back into the waiting queue
waiting_queue.appendleft(seq_group)
waiting_queue = deque(sorted(waiting_queue, key=self._get_priority))
self.waiting = waiting_queue
self.running = running_queue
return force_preemption_count
def _schedule_prefills( def _schedule_prefills(
self, self,
budget: SchedulingBudget, budget: SchedulingBudget,
...@@ -917,6 +990,10 @@ class Scheduler: ...@@ -917,6 +990,10 @@ class Scheduler:
curr_loras, curr_loras,
enable_chunking=False) enable_chunking=False)
if len(prefills.seq_groups
) == 0 and self.scheduler_config.policy == "priority":
self._schedule_priority_preemption(budget)
# Don't schedule decodes if prefills are scheduled. # Don't schedule decodes if prefills are scheduled.
# NOTE: If `_schedule_prefills` doesn't enable chunking, self.running # NOTE: If `_schedule_prefills` doesn't enable chunking, self.running
# only contains decode requests, not chunked prefills. # only contains decode requests, not chunked prefills.
...@@ -1477,14 +1554,14 @@ class Scheduler: ...@@ -1477,14 +1554,14 @@ class Scheduler:
# the number of new tokens that is dividable by the block size # the number of new tokens that is dividable by the block size
# to avoid partial block matching. # to avoid partial block matching.
block_size = self.cache_config.block_size block_size = self.cache_config.block_size
reminder = budget.token_budget % block_size remainder = budget.token_budget % block_size
if reminder != 0: if remainder != 0:
raise ValueError("When enabling chunked prefill and " raise ValueError("When enabling chunked prefill and "
"prefix caching, max_num_batched_tokens " "prefix caching, max_num_batched_tokens "
"(chunk size) must be dividable by " "(chunk size) must be dividable by "
"block size, but got chunk_size " "block size, but got chunk_size "
f"({budget.token_budget}) % block_size " f"({budget.token_budget}) % block_size "
f"({block_size}) = {reminder}") f"({block_size}) = {remainder}")
if remaining_token_budget < num_new_tokens: if remaining_token_budget < num_new_tokens:
num_new_tokens = (remaining_token_budget // num_new_tokens = (remaining_token_budget //
block_size) * block_size block_size) * block_size
......
...@@ -38,6 +38,12 @@ def _can_p2p(rank: int, world_size: int) -> bool: ...@@ -38,6 +38,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return True return True
def is_weak_contiguous(inp: torch.Tensor):
return inp.is_contiguous() or (inp.storage().nbytes() -
inp.storage_offset() * inp.element_size()
== inp.numel() * inp.element_size())
class CustomAllreduce: class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
...@@ -229,8 +235,19 @@ class CustomAllreduce: ...@@ -229,8 +235,19 @@ class CustomAllreduce:
ops.register_graph_buffers(self._ptr, handles, offsets) ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor): def should_custom_ar(self, inp: torch.Tensor):
return ops.should_custom_ar(inp, self.max_size, self.world_size, if self.disabled:
self.full_nvlink) return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
# all reduce, assuming inp tensor is IPC registered with register_buffer, # all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers # or, in the context of cuda graphs, register_graph_buffers
......
...@@ -9,11 +9,12 @@ from unittest.mock import patch ...@@ -9,11 +9,12 @@ from unittest.mock import patch
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from zmq import IPV6 # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
...@@ -196,7 +197,9 @@ class MessageQueue: ...@@ -196,7 +197,9 @@ class MessageQueue:
# see http://api.zeromq.org/3-3:zmq-setsockopt for more details # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
self.local_socket.setsockopt(XPUB_VERBOSE, True) self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port() local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}") socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
logger.debug("Binding to %s", socket_addr)
self.local_socket.bind(socket_addr)
self.current_idx = 0 self.current_idx = 0
...@@ -212,7 +215,10 @@ class MessageQueue: ...@@ -212,7 +215,10 @@ class MessageQueue:
self.remote_socket = context.socket(XPUB) self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True) self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port() remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") if is_valid_ipv6_address(connect_ip):
self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://*:{remote_subscribe_port}"
self.remote_socket.bind(socket_addr)
else: else:
remote_subscribe_port = None remote_subscribe_port = None
...@@ -255,8 +261,9 @@ class MessageQueue: ...@@ -255,8 +261,9 @@ class MessageQueue:
self.local_socket = context.socket(SUB) self.local_socket = context.socket(SUB)
self.local_socket.setsockopt_string(SUBSCRIBE, "") self.local_socket.setsockopt_string(SUBSCRIBE, "")
self.local_socket.connect( socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") logger.debug("Connecting to %s", socket_addr)
self.local_socket.connect(socket_addr)
self.remote_socket = None self.remote_socket = None
else: else:
...@@ -270,8 +277,11 @@ class MessageQueue: ...@@ -270,8 +277,11 @@ class MessageQueue:
self.remote_socket = context.socket(SUB) self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect( if is_valid_ipv6_address(handle.connect_ip):
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") self.remote_socket.setsockopt(IPV6, 1)
socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}"
logger.debug("Connecting to %s", socket_addr)
self.remote_socket.connect(socket_addr)
return self return self
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment