Unverified Commit e7596371 authored by weiyu's avatar weiyu Committed by GitHub
Browse files

[Refactor][TPU] Remove torch_xla path and use tpu-inference (#30808)


Signed-off-by: default avatarWei-Yu Lin <weiyulin@google.com>
Signed-off-by: default avatarweiyu <62784299+weiyu0824@users.noreply.github.com>
parent 0dd5dee9
This diff is collapsed.
...@@ -69,7 +69,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta): ...@@ -69,7 +69,6 @@ class AttentionBackendEnum(Enum, metaclass=_AttentionBackendEnumMeta):
"vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend" "vllm.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"
) )
FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend" FLASH_ATTN_MLA = "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
PALLAS = "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend" IPEX = "vllm.v1.attention.backends.ipex.IpexAttentionBackend"
NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend" NO_ATTENTION = "vllm.v1.attention.backends.no_attention.NoAttentionBackend"
FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" FLEX_ATTENTION = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
......
...@@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp): ...@@ -227,28 +227,3 @@ class MMEncoderAttention(CustomOp):
"XPU only supports FLASH_ATTN for vision attention." "XPU only supports FLASH_ATTN for vision attention."
) )
return self._forward_fa(query, key, value, cu_seqlens, max_seqlen) return self._forward_fa(query, key, value, cu_seqlens, max_seqlen)
def forward_tpu(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
max_seqlen: torch.Tensor | None = None, # Only used for Flash Attention
) -> torch.Tensor:
assert self.attn_backend == AttentionBackendEnum.PALLAS, (
f"MMEncoderAttention on TPU only supports PALLAS backend, "
f"but got {self.attn_backend}."
)
if cu_seqlens is None:
query, key, value = (x.transpose(1, 2) for x in (query, key, value))
from torch_xla.experimental.custom_kernel import flash_attention
out = flash_attention(query, key, value, sm_scale=self.scale)
out = out.transpose(1, 2)
return out
logger.warning_once(
"PALLAS backend with cu_seqlens is not supported for ViT yet. ",
"Falling back to SDPA implementation.",
)
return self._forward_sdpa(query, key, value, cu_seqlens)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import torch
from torch.distributed import ProcessGroup
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms.tpu import USE_TPU_INFERENCE
from .base_device_communicator import DeviceCommunicatorBase
USE_RAY = parallel_config = (
get_current_vllm_config().parallel_config.distributed_executor_backend == "ray"
)
logger = init_logger(__name__)
if not USE_TPU_INFERENCE:
logger.info("tpu_inference not found, using vLLM's TpuCommunicator")
if current_platform.is_tpu():
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from torch_xla._internal import pjrt
from torch_xla.distributed.xla_multiprocessing import (
create_optimized_replica_groups,
)
if USE_RAY:
from vllm.v1.executor import ray_utils
class TpuCommunicator(DeviceCommunicatorBase):
def __init__(
self,
cpu_group: ProcessGroup,
device: torch.device | None = None,
device_group: ProcessGroup | None = None,
unique_name: str = "",
):
super().__init__(cpu_group, device, device_group, unique_name)
# NOTE(woosuk): When using TP > 1 on TPUs, every TPU on the same node
# must be used together. Therefore, the local rank and world size can
# be simply calculated as follows.
global_rank = self.global_rank
global_world_size = self.global_world_size
if USE_RAY:
logger.info("TpuCommunicator initialized with RAY")
# Calculate how many TPU nodes are in the current deployment. This
# is the Ray placement group if it is deployed with Ray. Default
# to the number of TPU nodes in the Ray cluster. The number of TPU
# nodes is computed by the total number of TPUs divided by the
# number of TPU accelerators per node, to account for clusters
# with both CPUs and TPUs.
num_nodes = ray_utils.get_num_tpu_nodes()
num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
if num_nodes_in_pg > 0:
num_nodes = num_nodes_in_pg
local_world_size = global_world_size // num_nodes
local_rank = global_rank % local_world_size
else:
logger.info("TpuCommunicator initialized with MP")
# Sanity: Verify we run on a single host
num_hosts = torch_xla.tpu.num_tpu_workers()
assert num_hosts == 1
# Get the current number of TPUs (we have locally)
local_world_size = torch_xla.tpu.num_available_chips()
# Get current rank
local_rank = global_rank % local_world_size
# Ensure environment variables are set for multihost deployments.
# On GKE, this is needed for libtpu and TPU driver to know which TPU
# chip is actually visible. Otherwise the TPU driver will fail to
# initialize because the number of devices would be different from
# the number of visible worker addresses.
os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
pjrt.initialize_multiprocess(local_rank, local_world_size)
xr._init_world_size_ordinal()
self.groups = create_optimized_replica_groups()
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
# TODO: Remove the groups specification after XLA compiler can support
# auto-reordering the ring order for all-reduce.
return xm.all_reduce(xm.REDUCE_SUM, input_, groups=self.groups)
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
assert dim == -1, "TPUs only support dim=-1 for all-gather."
return xm.all_gather(input_, dim=dim)
...@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal ...@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Literal
import torch import torch
from vllm.attention.backends.abstract import AttentionBackend from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -251,9 +250,6 @@ class TpKVTopology: ...@@ -251,9 +250,6 @@ class TpKVTopology:
len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1 len(kv_cache_shape) == 5 and kv_cache_shape[0] == 1
) )
attn_backend = AttentionBackendEnum[self.attn_backend.get_name()]
self._use_pallas = attn_backend == AttentionBackendEnum.PALLAS
@property @property
def is_kv_layout_blocks_first(self) -> bool: def is_kv_layout_blocks_first(self) -> bool:
return self._is_kv_layout_blocks_first return self._is_kv_layout_blocks_first
...@@ -261,7 +257,7 @@ class TpKVTopology: ...@@ -261,7 +257,7 @@ class TpKVTopology:
@property @property
def split_k_and_v(self) -> bool: def split_k_and_v(self) -> bool:
# Whether to register regions for K and V separately (when present). # Whether to register regions for K and V separately (when present).
return not (self.is_mla or self._use_pallas or self.is_kv_layout_blocks_first) return not (self.is_mla or self.is_kv_layout_blocks_first)
@property @property
def tp_size(self) -> int: def tp_size(self) -> int:
......
...@@ -499,7 +499,6 @@ class MooncakeConnectorWorker: ...@@ -499,7 +499,6 @@ class MooncakeConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend, attn_backend=backend,
) )
self._use_pallas = self.kv_topo._use_pallas
self.zmq_ctx = zmq.Context() self.zmq_ctx = zmq.Context()
self.async_zmq_ctx = zmq.asyncio.Context() self.async_zmq_ctx = zmq.asyncio.Context()
......
...@@ -990,7 +990,6 @@ class NixlConnectorWorker: ...@@ -990,7 +990,6 @@ class NixlConnectorWorker:
total_num_kv_heads=self.model_config.get_total_num_kv_heads(), total_num_kv_heads=self.model_config.get_total_num_kv_heads(),
attn_backend=backend, attn_backend=backend,
) )
self._use_pallas = self.kv_topo._use_pallas
self._physical_blocks_per_logical_kv_block = 1 self._physical_blocks_per_logical_kv_block = 1
def _nixl_handshake( def _nixl_handshake(
...@@ -1648,9 +1647,6 @@ class NixlConnectorWorker: ...@@ -1648,9 +1647,6 @@ class NixlConnectorWorker:
# Num kv_heads > tp_size and P TP > D TP case, not supported # Num kv_heads > tp_size and P TP > D TP case, not supported
assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id)) assert not (tp_ratio < 0 and self.kv_topo.is_kv_replicated(remote_engine_id))
assert not self._use_pallas or tp_ratio == 1, (
"TPU (pallas_v1) DOES NOT support heterogeneous TP yet."
)
kv_cache_layout = ( kv_cache_layout = (
self.kv_cache_layout self.kv_cache_layout
if not self.use_host_buffer if not self.use_host_buffer
...@@ -1821,9 +1817,7 @@ class NixlConnectorWorker: ...@@ -1821,9 +1817,7 @@ class NixlConnectorWorker:
if len(self.device_kv_caches) == 0: if len(self.device_kv_caches) == 0:
return return
split_k_and_v = not ( split_k_and_v = not (self.use_mla or self.kv_topo.is_kv_layout_blocks_first)
self.use_mla or self._use_pallas or self.kv_topo.is_kv_layout_blocks_first
)
sample_cache = list(self.device_kv_caches.values())[0][0] sample_cache = list(self.device_kv_caches.values())[0][0]
for block_size_ratio, block_ids_list in block_ids_per_ratio.items(): for block_size_ratio, block_ids_list in block_ids_per_ratio.items():
assert block_size_ratio > 1, "Only nP < nD supported currently." assert block_size_ratio > 1, "Only nP < nD supported currently."
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections import OrderedDict
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_xla.distributed.spmd as xs
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
logger = init_logger(__name__)
class XlaQKVParallelLinear(nn.Module):
def __init__(self, qkv_linear: nn.Module, mesh: Optional["xs.Mesh"] = None):
super().__init__()
assert isinstance(qkv_linear, QKVParallelLinear)
self.skip_bias_add = qkv_linear.skip_bias_add
self.return_bias = qkv_linear.return_bias
assert qkv_linear.tp_size == 1, "TP > 1 is only supported under SPMD."
self.q_weight: Parameter
self.k_weight: Parameter
self.v_weight: Parameter
self.q_bias: Parameter | None
self.k_bias: Parameter | None
self.v_bias: Parameter | None
self._load_weights_from_qkv_linear(qkv_linear)
if mesh is not None:
self._shard_weight(mesh)
def _shard_weight(self, mesh: "xs.Mesh"):
self.q_weight = Parameter(self.q_weight.to("xla"), requires_grad=False)
self.k_weight = Parameter(self.k_weight.to("xla"), requires_grad=False)
self.v_weight = Parameter(self.v_weight.to("xla"), requires_grad=False)
xs.mark_sharding(self.q_weight, mesh, ("x", None))
xs.mark_sharding(self.k_weight, mesh, ("x", None))
xs.mark_sharding(self.v_weight, mesh, ("x", None))
if self.q_bias is not None:
assert self.k_bias is not None and self.v_bias is not None, (
"QKVParallelLinear should have q, k, and v biases together."
)
self.q_bias = Parameter(self.q_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.q_bias, mesh, ("x",))
self.k_bias = Parameter(self.k_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.k_bias, mesh, ("x",))
self.v_bias = Parameter(self.v_bias.to("xla"), requires_grad=False)
xs.mark_sharding(self.v_bias, mesh, ("x",))
def _load_weights_from_qkv_linear(self, qkv_linear: nn.Module):
q_proj_size, k_proj_size, _ = qkv_linear.output_sizes
# The weight of qkv linear is a concatenation of q, k, and v weights
# along the output dimension.
qkv_weight = qkv_linear.weight.data.cpu()
q_weight = Parameter(qkv_weight[:q_proj_size], requires_grad=False)
k_weight = Parameter(
qkv_weight[q_proj_size : q_proj_size + k_proj_size], requires_grad=False
)
v_weight = Parameter(
qkv_weight[q_proj_size + k_proj_size :], requires_grad=False
)
self.register_parameter("q_weight", q_weight)
self.register_parameter("k_weight", k_weight)
self.register_parameter("v_weight", v_weight)
if qkv_linear.bias is not None:
q_bias = Parameter(qkv_linear.bias[:q_proj_size], requires_grad=False)
k_bias = Parameter(
qkv_linear.bias[q_proj_size : q_proj_size + k_proj_size],
requires_grad=False,
)
v_bias = Parameter(
qkv_linear.bias[q_proj_size + k_proj_size :], requires_grad=False
)
self.register_parameter("q_bias", q_bias)
self.register_parameter("k_bias", k_bias)
self.register_parameter("v_bias", v_bias)
else:
self.register_parameter("q_bias", None)
self.register_parameter("k_bias", None)
self.register_parameter("v_bias", None)
def forward(self, input):
# Same forward functionality as QKVParallelLinear, but doing qkv porj
# separately.
q_bias = self.q_bias if not self.skip_bias_add else None
k_bias = self.k_bias if not self.skip_bias_add else None
v_bias = self.v_bias if not self.skip_bias_add else None
q_proj = F.linear(input, self.q_weight, q_bias)
k_proj = F.linear(input, self.k_weight, k_bias)
v_proj = F.linear(input, self.v_weight, v_bias)
# The q/k/v projections will be split outside of the QKVParallelLinear.
# Because we are replacing XlaQKVParallelLinear with the
# QKVParallelLinear, we need to concatenate q, k, and v projections to
# match the output shape of the QKVParallelLinear implementation even if
# it seems to be redundant.
# The concat and the following split will be noop, and should be
# optimized away by the compiler.
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=-1)
output_bias = (
torch.cat([q_bias, k_bias, v_bias], dim=-1) if self.skip_bias_add else None
)
if not self.return_bias:
return qkv_proj
return qkv_proj, output_bias
def partition_column_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, ColumnParallelLinear)
xs.mark_sharding(layer.weight, mesh, ("x", None))
logger.debug("Applied column-parallel sharding to %s", layer)
return layer
def partition_row_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, RowParallelLinear)
xs.mark_sharding(layer.weight, mesh, (None, "x"))
logger.debug("Applied row-parallel sharding to %s", layer)
return layer
def partition_qkv_parallel_linear(
layer: torch.nn.Module, mesh: xs.Mesh
) -> torch.nn.Module:
assert isinstance(layer, QKVParallelLinear)
xla_layer = XlaQKVParallelLinear(layer, mesh)
logger.debug("Applied qkv parallel sharding to %s", layer)
return xla_layer
MODULE_TYPE_TO_WRAPPING_FUNC = OrderedDict(
[
("QKVParallelLinear", partition_qkv_parallel_linear),
("ColumnParallelLinear", partition_column_parallel_linear),
("RowParallelLinear", partition_row_parallel_linear),
]
)
def get_fqn(module):
# Get the fully qualified name of the module
return module.__class__.__qualname__
def shard_model(model: torch.nn.Module, mesh: "xs.Mesh") -> None:
"""
Recursively check a PyTorch model and apply appropriate sharding based on
the MODULE_TYPE_TO_WRAPPING_FUNC mapping.
Args:
model: torch.nn.Module to process
mesh: An XLA SPMD mesh object used for sharding
"""
def _process_module(module, name=None, parent=None):
for module_type, wrapping_func in MODULE_TYPE_TO_WRAPPING_FUNC.items():
if get_fqn(module) == module_type:
wrapped_module = wrapping_func(module, mesh)
assert parent is not None and name is not None, (
"Top Level module is not expected to be wrapped."
)
if wrapped_module is not module:
# Wrapped module and module are different py object.
# The original module should be replaced by the
# wrapped_module.
logger.debug("replace %s with %s", module, wrapped_module)
setattr(parent, name, wrapped_module)
module = wrapped_module
break
for child_name, child_module in list(module.named_children()):
_process_module(child_module, child_name, module)
_process_module(model)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.ops.xla_ops.lora_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
__all__ = ["bgmv_expand", "bgmv_expand_slice", "bgmv_shrink"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import jax
import jax.numpy as jnp
import torch
import torch.nn.functional as F
import torch_xla.core.xla_builder as xb
from torch.library import impl
from torch_xla.experimental.custom_kernel import XLA_LIB, jax_import_guard
@jax.jit
def bgmv_jax(inputs, loras, idxs):
return jnp.einsum(
"td,tX,Xld->tl",
inputs,
jax.nn.one_hot(idxs, loras.shape[0], dtype=inputs.dtype),
loras,
)
XLA_LIB.define("bgmv(Tensor inputs, Tensor loras, Tensor idxs) -> Tensor")
@impl(XLA_LIB, "bgmv", "XLA")
def bgmv_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
jax_import_guard()
return xb.call_jax(bgmv_jax, (inputs, loras, idxs))
@impl(XLA_LIB, "bgmv", "CompositeExplicitAutograd")
def bgmv_non_xla(inputs: torch.Tensor, loras: torch.Tensor, idxs: torch.IntTensor):
T, _ = inputs.shape
if len(loras.shape) == 4:
loras = loras.squeeze(axis=1)
_, L, _ = loras.shape
return torch.empty((T, L), device=inputs.device)
def bgmv_expand(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
limit = output_tensor.shape[0]
if outputs.shape[0] == 1 and output_tensor.shape[0] != 1:
limit = 1
if output_tensor.shape[1] > outputs.shape[1]:
outputs = F.pad(outputs, (0, output_tensor.shape[1] - outputs.shape[1], 0, 0))
if add_inputs:
return output_tensor + outputs[:limit, : output_tensor.shape[1]]
else:
return outputs[:limit, : output_tensor.shape[1]]
def bgmv_shrink(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
lora_indices_tensor: torch.Tensor,
scaling: float = 1.0,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
scaling (float, optional): Scalar multiplier applied to the output.
"""
return scaling * torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
def bgmv_expand_slice(
inputs: torch.Tensor,
lora_b_weights: torch.Tensor,
output_tensor: torch.Tensor,
lora_indices_tensor: torch.Tensor,
slice_offset: int,
slice_size: int,
add_inputs: bool = True,
):
"""
Args:
inputs (torch.Tensor): Input tensor of shape [num_tokens, hidden_size].
lora_b_weights (torch.Tensor): LoRA weights of shape
[num_loras, lora_rank, hidden_size].
output_tensor (torch.Tensor): output tensor of shape
[num_tokens, hidden_size * num_slices].
lora_indices_tensor (torch.Tensor): Tensor of shape [num_tokens]
indicating which LoRA matrix to use for each token.
add_inputs (bool): Whether or not to add the input tensor to the output
tensor.
"""
outputs = torch.ops.xla.bgmv(inputs, lora_b_weights, lora_indices_tensor)
outputs = F.pad(
outputs,
(
slice_offset,
output_tensor.shape[1] - (slice_offset + slice_size),
0,
0,
),
)
if add_inputs:
return output_tensor + outputs
else:
return outputs
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING
import torch
import torch.nn.functional as F
import torch_xla
from vllm.lora.ops.xla_ops import bgmv_expand, bgmv_expand_slice, bgmv_shrink
from vllm.lora.punica_wrapper.utils import convert_mapping
if TYPE_CHECKING:
# avoid circuit import
from vllm.lora.layers import LoRAMapping
from .punica_base import PunicaWrapperBase
class PunicaWrapperTPU(PunicaWrapperBase):
"""
PunicaWrapperTPU is designed to manage and provide metadata for the punica
kernel. The main function is to maintain the state information for
Multi-LoRA, and to provide the interface for the pytorch punica ops.
"""
def __init__(
self,
max_num_batched_tokens: int,
max_batches: int,
device: torch.device | str,
**kwargs,
):
PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device)
# PunicaWrapperBase defines some tensors with dtype=torch.int64, which
# isn't supported by the TPU. So convert those tensors to int32.
# Not all of them are used by the TPU so only convert the useful ones.
self._token_lora_indices = self._token_lora_indices.to(dtype=torch.int32)
self._sampler_indices = self._sampler_indices.to(dtype=torch.int32)
self._sampler_indices_padded = self._sampler_indices_padded.to(
dtype=torch.int32
)
torch.ops.xla.dynamo_set_buffer_donor_(self._token_lora_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._sampler_indices_padded, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._embeddings_indices, True)
torch.ops.xla.dynamo_set_buffer_donor_(self._lora_indices_per_batch, True)
torch._dynamo.mark_dynamic(self._token_lora_indices, 0)
torch._dynamo.mark_dynamic(self._embeddings_indices, 1)
torch._dynamo.mark_dynamic(self._sampler_indices_padded, 0)
def _get_token_lora_indices(self, x: torch.Tensor) -> torch.IntTensor:
return torch.narrow(self._token_lora_indices, 0, 0, x.size(0))
@property
def embeddings_indices(self) -> torch.Tensor:
"""
This property provides access to the indices used for lora embeddings,
specifically for VocabParallelEmbeddingWithLoRA.
"""
return self._embeddings_indices[:]
@property
def sampler_indices_padded(self) -> torch.Tensor:
"""
This property provides access to padded sampler indices.
"""
return self._sampler_indices_padded[:]
def shrink(
self,
x: torch.Tensor,
w_t_all: torch.Tensor,
scale: float,
):
return bgmv_shrink(x, w_t_all, self._get_token_lora_indices(x), scale)
def expand(
self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, add_inputs: bool
):
return bgmv_expand(x, w_t_all, y, self._get_token_lora_indices(x), add_inputs)
def expand_slice(
self,
y: torch.Tensor,
x: torch.Tensor,
w_t_all: torch.Tensor,
y_offset: int,
y_slice_size: int,
add_inputs: bool,
) -> torch.Tensor:
return bgmv_expand_slice(
x,
w_t_all,
y,
self._get_token_lora_indices(x),
y_offset,
y_slice_size,
add_inputs,
)
def add_shrink(
self,
y: tuple[torch.Tensor, ...] | torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
scale: float,
**kwargs,
) -> torch.Tensor | None:
"""
Performs GEMM for multiple slices of lora_a.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (x @ lora_a_stacked[i]) * scale
Args:
y (Union[tuple[torch.Tensor, ...], torch.Tensor]): Output tensors
x (torch.Tensor): Input tensor
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weights
scale (float): Scaling factor for the operation
"""
torch.ops.xla.dynamo_set_buffer_donor_(y, True)
x = x.view(-1, x.shape[-1])
for slice_idx in range(len(lora_a_stacked)):
lora_s = lora_a_stacked[slice_idx]
y_s = self.shrink(x, lora_s, scale)
y[slice_idx, :, :] = y_s # type: ignore[index]
return y
def add_expand(
self,
y: torch.Tensor,
x: tuple[torch.Tensor, ...] | torch.Tensor,
lora_b_stacked: tuple[torch.Tensor, ...],
output_slices: tuple[int, ...],
offset_start: int = 0,
add_inputs=True,
**kwargs,
) -> torch.Tensor:
"""
Performs GEMM for multiple slices of lora_b.
Semantics:
for i in range(len(lora_b_stacked)):
slice = output_slices[i]
y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i]
offset += slice
Args:
y (torch.Tensor): Output tensor.
x (Union[tuple[torch.Tensor, ...], torch.Tensor]): Input tensors
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight
output_slices (tuple[int, ...]): Every slice's size
add_inputs (bool): Defaults to True.
"""
y_org = y
y = y.view(-1, y.shape[-1])
offset_left = 0
for slice_idx in range(len(lora_b_stacked)):
y = self.expand_slice(
y,
x[slice_idx],
lora_b_stacked[slice_idx],
offset_left,
output_slices[slice_idx],
add_inputs=add_inputs,
)
offset_left += output_slices[slice_idx]
return y.view_as(y_org)
def add_lora_embedding(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_b_stacked: torch.Tensor,
add_inputs: bool = True,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for VocabParallelEmbeddingWithLoRA.
Semantics:
y += x @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_b_stacked (torch.Tensor): lora_b's weights.
add_inputs (bool): Default to True.
"""
# Embedding layer only needs the expand op
return self.expand(y, x, lora_b_stacked, add_inputs)
def add_lora_linear(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: tuple[torch.Tensor, ...],
lora_b_stacked: tuple[torch.Tensor, ...],
scale: float,
output_slices: tuple[int, ...],
*,
buffer: tuple[torch.Tensor, ...] | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applicable to linear-related lora.
Semantics:
for i in range(len(lora_a_stacked)):
y[i] += (
x[i].unsqueeze(0)
@ lora_a_stacked[indices[i], layer_idx, :, :]
@ lora_b_stacked[indices[i], layer_idx, :, :]
* scale
).squeeze(0)
Args:
y (torch.Tensor): Output tensor. Will not be changed in-place.
x (torch.Tensor): Input tensor (T, E)
lora_a_stacked (tuple[torch.Tensor, ...]): lora_a's weight.
lora_b_stacked (tuple[torch.Tensor, ...]): lora_b's weight.
scale (float): Scaling factor.
output_slices (tuple[int, ...]): Every slice's size.
buffer (Optional[tuple[torch.Tensor, ...]]): Defaults to None.
"""
assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices)
if buffer is None:
r = lora_b_stacked[0].size(-1)
T = x.size(0)
buffer = torch.zeros(
(len(output_slices), T, r),
dtype=x.dtype,
device=x.device,
)
buffer = self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs)
return self.add_expand(
y, buffer, lora_b_stacked, output_slices, add_inputs=True, **kwargs
)
def add_lora_logits(
self,
y: torch.Tensor,
x: torch.Tensor,
lora_a_stacked: torch.Tensor,
lora_b_stacked: torch.Tensor,
scale,
*,
buffer: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
"""
Applies lora specifically for LogitsProcessorWithLoRA.
Semantics:
buffer = (x @ lora_a_stacked) * scale
y += buffer @ lora_b_stacked
Args:
y (torch.Tensor): Output tensor.
x (torch.Tensor): Input tensor.
lora_a_stacked (torch.Tensor): lora_a's weights.
lora_b_stacked (torch.Tensor):lora_b's weights.
scale (float): Scaling factor.
buffer (Optional[torch.Tensor]):Default to None.
"""
y_org = y
y = y.view(-1, y.shape[-1])
x = x.view(-1, x.shape[-1])
sampler_indices = torch.narrow(self._sampler_indices, 0, 0, x.size(0))
buffer = bgmv_shrink(x, lora_a_stacked, sampler_indices, scale)
y = bgmv_expand(buffer, lora_b_stacked, y, sampler_indices, add_inputs=True)
return y.view_as(y_org)
# This performs the same tensor ops as the base method, except it does them
# on the CPU then transfers the results to the TPU
def _update_base_metadata(
self,
mapping: "LoRAMapping",
lora_index_to_id: list[int | None],
max_loras: int,
vocab_size: int,
):
# Make sure we don't accidentally collect outside operations
torch_xla.sync()
# Pad the prompt mapping to avoid running into recompiles on the TPU
# TODO: Should this happen inside mapping internally? If so how can we
# avoid having backend specific LoRAMapping classes?
mapping.prompt_mapping = self._pad_prompt_mapping(mapping.prompt_mapping)
(
base_indices,
sampler_indices,
sampler_indices_padded,
embeddings_indices,
indices_len,
) = convert_mapping(
mapping,
lora_index_to_id,
max_loras,
vocab_size,
0, # extra_vocab_size
"cpu",
)
self._token_lora_indices = self._pad_to_shape(
base_indices, self._token_lora_indices.shape, dims=1
).to(self.device)
self._sampler_indices = self._pad_to_shape(
sampler_indices, self._sampler_indices.shape, dims=1
).to(self.device)
self._sampler_indices_padded = self._pad_to_shape(
sampler_indices_padded, self._sampler_indices_padded.shape, dims=1
).to(self.device)
self._embeddings_indices = self._pad_to_shape(
embeddings_indices, self._embeddings_indices.shape, dims=2
).to(self.device)
self.indices_len[:] = indices_len
def _update_prefill_metadata(self, token_lora_tensor: torch.Tensor) -> None:
self.batch_size = 1
self._lora_indices_per_batch[: self.batch_size] = token_lora_tensor[
: self.batch_size
]
def _pad_prompt_mapping(self, prompt_mapping: tuple[int, ...]) -> tuple[int, ...]:
num_reqs = len(prompt_mapping)
# From vllm/v1/worker/tpu_model_runner:51, but need to avoid a circular
# import
MIN_NUM_SEQS = 8
padded_num_reqs = max(2 ** math.ceil(math.log2(num_reqs)), MIN_NUM_SEQS)
pad_len = padded_num_reqs - num_reqs
padding = [-1] * pad_len
return tuple(list(prompt_mapping) + padding)
def _pad_to_shape(self, src, target_shape, dims=1):
if dims == 1:
pad_len = target_shape[0] - src.shape[0]
return F.pad(src, (0, pad_len), value=0).to(torch.int32)
else:
pad_rows = target_shape[0] - src.shape[0]
pad_cols = target_shape[1] - src.shape[1]
return F.pad(src, (0, pad_cols, 0, pad_rows), value=0).to(torch.int32)
...@@ -66,12 +66,6 @@ else: ...@@ -66,12 +66,6 @@ else:
eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record eplb_map_to_physical_and_record = _eplb_map_to_physical_and_record
from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
from vllm.model_executor.layers.fused_moe.fused_moe_method_base import ( from vllm.model_executor.layers.fused_moe.fused_moe_method_base import (
FusedMoEMethodBase, FusedMoEMethodBase,
) )
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn.functional as F
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of an int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min <= max, "min must be less than or equal to max."
def searchsorted(
sorted_sequence: torch.Tensor, values_to_search: torch.Tensor
) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
bin_edges = torch.linspace(min, max, max - min + 1, dtype=input.dtype).to(
input.device
)
return searchsorted(bin_edges, input).to(torch.int32)
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
renormalize: bool = False,
) -> torch.Tensor:
"""
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
"""
assert expert_map is None, "expert_map is not supported for pallas MoE."
import torch_xla.experimental.custom_kernel # noqa: F401
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
device = hidden_states.device
dtype = hidden_states.dtype
assert (num_tokens * topk) % 16 == 0, (
"The Pallas GMM kernel requires num_tokens * topk to be a multiple of "
f"16 but got {num_tokens * topk}"
)
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, topk_indices = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
topk_indices = topk_indices.flatten()
topk_argsort_indices = topk_indices.argsort()
topk_argsort_revert_indices = topk_argsort_indices.argsort()
token_indices = torch.arange(num_tokens, device=device).repeat_interleave(topk)
token_indices = token_indices[topk_argsort_indices]
group_sizes = _histogram(topk_indices.to(torch.int32), 0, num_experts - 1)
x = hidden_states[token_indices]
x = torch.ops.xla.gmm(x, w1, group_sizes, transpose_rhs=True)
x = F.silu(x[..., :intermediate_size]) * x[..., intermediate_size:]
x = torch.ops.xla.gmm(x, w2, group_sizes, transpose_rhs=True)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * topk_weights.unsqueeze(dim=-1)
x = x.sum(dim=-2)
x = x.reshape(orig_shape)
return x
...@@ -47,10 +47,6 @@ if current_platform.is_cuda_alike(): ...@@ -47,10 +47,6 @@ if current_platform.is_cuda_alike():
else: else:
TritonExperts = None # type: ignore TritonExperts = None # type: ignore
if current_platform.is_tpu():
from .moe_pallas import fused_moe as fused_moe_pallas
else:
fused_moe_pallas = None # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -390,53 +386,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -390,53 +386,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=layer.custom_routing_function, custom_routing_function=layer.custom_routing_function,
) )
def forward_tpu( if current_platform.is_cpu():
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
x: torch.Tensor,
router_logits: torch.Tensor,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
assert not layer.use_grouped_topk
assert layer.num_expert_group is None
assert layer.topk_group is None
assert layer.custom_routing_function is None
assert layer.apply_router_weight_on_input is False
if layer.scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU."
)
if layer.e_score_correction_bias is not None:
raise NotImplementedError(
"Expert score correction bias is not supported for TPU."
)
assert layer.activation == "silu", (
f"{layer.activation} is not supported for TPU."
)
assert layer.routed_scaling_factor == 1.0, (
f"routed_scaling_factor {layer.routed_scaling_factor} is "
"not supported for TPU."
)
if (
layer.enable_eplb is not False
or layer.expert_load_view is not None
or layer.logical_to_physical_map is not None
or layer.logical_replica_count is not None
):
raise NotImplementedError("Expert load balancing is not supported for TPU.")
return fused_moe_pallas(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk=layer.top_k,
gating_output=router_logits,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,
renormalize=layer.renormalize,
)
if current_platform.is_tpu():
forward_native = forward_tpu
elif current_platform.is_cpu():
forward_native = forward_cpu forward_native = forward_cpu
elif current_platform.is_xpu(): elif current_platform.is_xpu():
forward_native = forward_xpu forward_native = forward_xpu
......
...@@ -11,7 +11,6 @@ logger = init_logger(__name__) ...@@ -11,7 +11,6 @@ logger = init_logger(__name__)
QuantizationMethods = Literal[ QuantizationMethods = Literal[
"awq", "awq",
"deepspeedfp", "deepspeedfp",
"tpu_int8",
"fp8", "fp8",
"ptpc_fp8", "ptpc_fp8",
"fbgemm_fp8", "fbgemm_fp8",
...@@ -129,12 +128,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -129,12 +128,10 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .ptpc_fp8 import PTPCFp8Config from .ptpc_fp8 import PTPCFp8Config
from .rtn import RTNConfig from .rtn import RTNConfig
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig
method_to_config: dict[str, type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
"awq": AWQConfig, "awq": AWQConfig,
"deepspeedfp": DeepSpeedFPConfig, "deepspeedfp": DeepSpeedFPConfig,
"tpu_int8": Int8TpuConfig,
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
"fp_quant": FPQuantConfig, "fp_quant": FPQuantConfig,
......
...@@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer ...@@ -19,9 +19,6 @@ from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKer
from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import (
TritonScaledMMLinearKernel, TritonScaledMMLinearKernel,
) )
from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import (
XLAScaledMMLinearKernel,
)
from vllm.platforms import PlatformEnum, current_platform from vllm.platforms import PlatformEnum, current_platform
# in priority/performance order (when available) # in priority/performance order (when available)
...@@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = { ...@@ -29,7 +26,6 @@ _POSSIBLE_KERNELS: dict[PlatformEnum, list[type[ScaledMMLinearKernel]]] = {
PlatformEnum.CPU: [CPUScaledMMLinearKernel], PlatformEnum.CPU: [CPUScaledMMLinearKernel],
PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.CUDA: [CutlassScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel], PlatformEnum.ROCM: [AiterScaledMMLinearKernel, TritonScaledMMLinearKernel],
PlatformEnum.TPU: [XLAScaledMMLinearKernel],
} }
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import warnings
import torch
from functorch.experimental.control_flow import cond # noqa: F401
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
)
from vllm.platforms import current_platform
from .ScaledMMLinearKernel import ScaledMMLinearKernel, ScaledMMLinearLayerConfig
class XLAScaledMMLinearKernel(ScaledMMLinearKernel):
@classmethod
def is_supported(
cls, compute_capability: int | None = None
) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "Requires TPU."
return True, None
@classmethod
def can_implement(cls, c: ScaledMMLinearLayerConfig) -> tuple[bool, str | None]:
if not current_platform.is_tpu():
return False, "ScaledMMXLA requires running on TPU."
if c.is_static_input_scheme:
return False, "ScaledMMXLA requires dynamic activation scales."
if not c.input_symmetric:
return False, "ScaledMMXLA requires symmetric activation scales."
if not c.is_channelwise:
return False, "ScaledMMXLA requires channelwise weight scales"
return True, None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# WEIGHT
# [out, in] (different than cutlass_scaled_mm)
weight = getattr(layer, self.w_q_name)
replace_parameter(
layer, self.w_q_name, torch.nn.Parameter(weight.data, requires_grad=False)
)
# WEIGHT SCALE
# XLA kernels support only per-tensor and per-channel.
# If we have a fused module (QKV, MLP) with per tensor scales (thus N
# scales being passed to the kernel), convert to the per-channel case.
is_fused_module = len(layer.logical_widths) > 1
weight_scale = getattr(layer, self.w_s_name)
if is_fused_module and not self.config.is_channelwise:
weight_scale = convert_to_channelwise(weight_scale, layer.logical_widths)
# [out_channel,] (different than cutlass_scaled_mm)
weight_scale = weight_scale.squeeze(-1)
replace_parameter(
layer,
self.w_s_name,
torch.nn.Parameter(weight_scale.data, requires_grad=False),
)
# Only support symmetric dynamic activation quantization.
setattr(layer, self.i_s_name, None)
setattr(layer, self.i_zp_name, None)
setattr(layer, self.azp_adj_name, None)
# Filter warning for cond usage in apply_weights. It is okay
# to specialize the graph since bias is not dynamic.
warnings.filterwarnings(
"ignore",
message="Pred is a Python constant. When used with torch.cond, it specializes on one of the branches.", # noqa: E501
)
def no_add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x
def add_bias(self, x: torch.Tensor, bias: torch.Tensor | None):
return x + bias
def apply_weights(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
w_q, w_s, _, _, _ = self._get_weight_params(layer)
# Required to register custom ops.
import torch_xla.experimental.custom_kernel # noqa: F401
out = torch.ops.xla.quantized_matmul_int8(
x,
w_q,
w_s,
quantize_activation=True,
)
# Explicitly capture control flow to make dynamo happy.
# https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501
return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional
import torch
from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.parameter import ModelWeightParameter
ACTIVATION_SCHEMES = ["none", "dynamic"]
class Int8TpuConfig(QuantizationConfig):
"""Int8 Quantization Config class for TPU Backend."""
def __init__(
self,
activation_scheme: str = "none",
) -> None:
super().__init__()
if activation_scheme not in ACTIVATION_SCHEMES:
raise ValueError(f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
def get_name(self) -> QuantizationMethods:
return "tpu_int8"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
raise NotImplementedError("This function should not be called with TPU Backend")
@staticmethod
def get_config_filenames() -> list[str]:
return []
@classmethod
def from_config(cls, config: dict[str, Any]) -> "Int8TpuConfig":
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
return cls(activation_scheme=activation_scheme)
def get_quant_method(
self, layer: Module, prefix: str
) -> Optional["TPUInt8LinearMethod"]:
if isinstance(layer, LinearBase):
return TPUInt8LinearMethod(self)
return None
class TPUInt8LinearMethod(LinearMethodBase):
"""Int8 Linear method for TPU Quant."""
def __init__(self, quant_config: Int8TpuConfig):
self.quant_config = quant_config
self.quantize_activation = False
if self.quant_config.activation_scheme == "dynamic":
self.quantize_activation = True
def create_weights(
self,
layer: Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
weight_loader = extra_weight_attrs.get("weight_loader")
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("weight", weight)
def _quantize_weight(
self, weight: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
weight_dtype = weight.dtype
weight = weight.cpu().to(torch.float32)
n_bit = 8
eps = 1e-5
max_int = 2 ** (n_bit - 1) - 1
min_int = -(2 ** (n_bit - 1))
max_val = weight.abs().amax(dim=-1, keepdim=True)
max_val = max_val.clamp(min=eps)
qscale = max_val / max_int
qweight = torch.clamp(
torch.round(weight * (1.0 / qscale)), min_int, max_int
).to(torch.int8)
qscale = qscale.squeeze().to(weight_dtype)
return qweight, qscale
def process_weights_after_loading(self, layer: Module) -> None:
layer.weight = Parameter(layer.weight.data, requires_grad=False)
device = layer.weight.device
qweight, qscale = self._quantize_weight(layer.weight)
qweight = qweight.to(device)
qscale = qscale.to(device)
layer.weight = Parameter(qweight, requires_grad=False)
layer.scale = Parameter(qscale, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
try:
import torch_xla.experimental.custom_kernel # noqa: F401
except ImportError as err:
raise ImportError(
"Please install torch_xla by following the instructions at "
"https://docs.vllm.ai/en/latest/getting_started/tpu-installation.html " # noqa: E501
"to run vLLM on TPU."
) from err
weight = layer.weight
scale = layer.scale
out = torch.ops.xla.quantized_matmul_int8(
x, weight, scale, quantize_activation=self.quantize_activation
)
if bias is not None:
out = out + bias
return out
...@@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -30,7 +30,6 @@ from vllm.model_executor.model_loader.weight_utils import (
pt_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator, safetensors_weights_iterator,
) )
from vllm.platforms import current_platform
from vllm.transformers_utils.repo_utils import list_filtered_repo_files from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -241,22 +240,6 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.pt_load_map_location, self.load_config.pt_load_map_location,
) )
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_INFERENCE
if not USE_TPU_INFERENCE:
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program.
import torch_xla
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
torch_xla.sync(wait=False)
weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0: if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter() self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix. # Apply the prefix.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment