Unverified Commit 41ca62cf authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

[Misc] Add CustomOp interface for device portability (#5255)

parent 974fc9b8
...@@ -44,7 +44,7 @@ def test_act_and_mul( ...@@ -44,7 +44,7 @@ def test_act_and_mul(
elif activation == "gelu_tanh": elif activation == "gelu_tanh":
layer = GeluAndMul(approximate="tanh") layer = GeluAndMul(approximate="tanh")
out = layer(x) out = layer(x)
ref_out = layer._forward(x) ref_out = layer.forward_native(x)
# The SiLU and GELU implementations are equivalent to the native PyTorch # The SiLU and GELU implementations are equivalent to the native PyTorch
# implementations, so we can do exact comparison. # implementations, so we can do exact comparison.
assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0) assert torch.allclose(out, ref_out, atol=0.0, rtol=0.0)
...@@ -72,7 +72,7 @@ def test_activation( ...@@ -72,7 +72,7 @@ def test_activation(
x = torch.randn(num_tokens, d, dtype=dtype) x = torch.randn(num_tokens, d, dtype=dtype)
layer = activation() layer = activation()
out = layer(x) out = layer(x)
ref_out = layer._forward(x) ref_out = layer.forward_native(x)
assert torch.allclose(out, assert torch.allclose(out,
ref_out, ref_out,
atol=get_default_atol(out), atol=get_default_atol(out),
......
...@@ -42,7 +42,7 @@ def test_rms_norm( ...@@ -42,7 +42,7 @@ def test_rms_norm(
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
ref_out = layer._forward(x, residual) ref_out = layer.forward_native(x, residual)
out = layer(x, residual) out = layer(x, residual)
# NOTE(woosuk): LayerNorm operators (including RMS) typically have larger # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
# numerical errors than other operators because they involve reductions. # numerical errors than other operators because they involve reductions.
......
...@@ -64,7 +64,7 @@ def test_rotary_embedding( ...@@ -64,7 +64,7 @@ def test_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key) ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, query, key) out_query, out_key = rope.forward(positions, query, key)
# Compare the results. # Compare the results.
assert torch.allclose(out_query, assert torch.allclose(out_query,
...@@ -121,7 +121,7 @@ def test_batched_rotary_embedding( ...@@ -121,7 +121,7 @@ def test_batched_rotary_embedding(
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key) ref_query, ref_key = rope.forward_native(positions, query, key)
out_query, out_key = rope.forward(positions, out_query, out_key = rope.forward(positions,
query, query,
key, key,
...@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora( ...@@ -195,7 +195,8 @@ def test_batched_rotary_embedding_multi_lora(
# NOTE(woosuk): The reference implementation should be executed first # NOTE(woosuk): The reference implementation should be executed first
# because the custom kernel is in-place. # because the custom kernel is in-place.
ref_query, ref_key = rope._forward(positions, query, key, query_offsets) ref_query, ref_key = rope.forward_native(positions, query, key,
query_offsets)
out_query, out_key = rope.forward(positions, query, key, out_query, out_key = rope.forward(positions, query, key,
query_offsets.flatten()) query_offsets.flatten())
# Compare the results. # Compare the results.
......
import torch.nn as nn
from vllm.utils import is_cpu, is_hip
class CustomOp(nn.Module):
    """Base class for operators with per-backend forward implementations.

    When an instance is constructed, ``dispatch_forward`` is asked once for
    the backend-specific method to use; the result is stored on the instance
    and ``forward`` delegates every call to it.

    Subclasses provide ``forward_cuda`` and/or ``forward_native``; the other
    ``forward_*`` methods fall back to one of those two unless overridden.
    """

    def __init__(self, *args, **kwargs):
        # Extra arguments are accepted but not passed on to nn.Module.
        super().__init__()
        # Resolved once here; there is no per-call backend selection.
        self._forward_method = self.dispatch_forward()

    def forward(self, *args, **kwargs):
        """Run the implementation chosen by ``dispatch_forward``."""
        impl = self._forward_method
        return impl(*args, **kwargs)

    def forward_native(self, *args, **kwargs):
        """PyTorch-native implementation of the forward method.

        This method is optional. If implemented, it can be used with
        compilers such as torch.compile or PyTorch XLA. Also, it can be used
        for testing purposes.
        """
        raise NotImplementedError()

    def forward_cuda(self, *args, **kwargs):
        """CUDA implementation; the base class does not provide one."""
        raise NotImplementedError()

    def forward_hip(self, *args, **kwargs):
        """HIP implementation.

        Defaults to the CUDA path, on the assumption that HIP ops are
        compatible with CUDA ops.
        """
        return self.forward_cuda(*args, **kwargs)

    def forward_xpu(self, *args, **kwargs):
        """XPU implementation.

        Defaults to the CUDA path, on the assumption that XPU ops are
        compatible with CUDA ops.
        """
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_cuda(*args, **kwargs)

    def forward_cpu(self, *args, **kwargs):
        """CPU implementation.

        Defaults to the CUDA path, on the assumption that CPU ops are
        compatible with CUDA ops.
        """
        return self.forward_cuda(*args, **kwargs)

    def forward_tpu(self, *args, **kwargs):
        """TPU implementation.

        Defaults to the PyTorch-native path, on the assumption that TPU ops
        are compatible with the native implementation.
        """
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def forward_gaudi(self, *args, **kwargs):
        """Gaudi implementation.

        Defaults to the PyTorch-native path, on the assumption that Gaudi ops
        are compatible with the native implementation.
        """
        # NOTE(woosuk): This is a placeholder for future extensions.
        return self.forward_native(*args, **kwargs)

    def dispatch_forward(self):
        """Return the bound ``forward_*`` method for the current backend.

        The HIP check runs first, then the CPU check; anything else gets the
        CUDA method.
        """
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.
        if is_hip():
            return self.forward_hip
        if is_cpu():
            return self.forward_cpu
        return self.forward_cuda
...@@ -6,14 +6,14 @@ import torch ...@@ -6,14 +6,14 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _custom_ops as ops
from vllm.distributed import (divide, get_tensor_model_parallel_rank, from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
class SiluAndMul(nn.Module): class SiluAndMul(CustomOp):
"""An activation function for SwiGLU. """An activation function for SwiGLU.
The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.
...@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module): ...@@ -23,12 +23,14 @@ class SiluAndMul(nn.Module):
return: (num_tokens, d) or (batch_size, seq_len, d) return: (num_tokens, d) or (batch_size, seq_len, d)
""" """
def _forward(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:] return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
...@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module): ...@@ -36,7 +38,7 @@ class SiluAndMul(nn.Module):
return out return out
class GeluAndMul(nn.Module): class GeluAndMul(CustomOp):
"""An activation function for GeGLU. """An activation function for GeGLU.
The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.
...@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module): ...@@ -52,12 +54,14 @@ class GeluAndMul(nn.Module):
if approximate not in ("none", "tanh"): if approximate not in ("none", "tanh"):
raise ValueError(f"Unknown approximate mode: {approximate}") raise ValueError(f"Unknown approximate mode: {approximate}")
def _forward(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2 d = x.shape[-1] // 2
return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:] return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
out = torch.empty(output_shape, dtype=x.dtype, device=x.device) out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
...@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module): ...@@ -71,28 +75,32 @@ class GeluAndMul(nn.Module):
return f'approximate={repr(self.approximate)}' return f'approximate={repr(self.approximate)}'
class NewGELU(nn.Module): class NewGELU(CustomOp):
def _forward(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi) c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c * return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0)))) (x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_new(out, x) ops.gelu_new(out, x)
return out return out
class FastGELU(nn.Module): class FastGELU(CustomOp):
def _forward(self, x: torch.Tensor) -> torch.Tensor: def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x))) (1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_fast(out, x) ops.gelu_fast(out, x)
return out return out
......
...@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union ...@@ -4,10 +4,10 @@ from typing import Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp
class RMSNorm(nn.Module): class RMSNorm(CustomOp):
"""Root mean square normalization. """Root mean square normalization.
Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight.
...@@ -23,7 +23,7 @@ class RMSNorm(nn.Module): ...@@ -23,7 +23,7 @@ class RMSNorm(nn.Module):
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps self.variance_epsilon = eps
def _forward( def forward_native(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
...@@ -43,11 +43,13 @@ class RMSNorm(nn.Module): ...@@ -43,11 +43,13 @@ class RMSNorm(nn.Module):
else: else:
return x, residual return x, residual
def forward( def forward_cuda(
self, self,
x: torch.Tensor, x: torch.Tensor,
residual: Optional[torch.Tensor] = None, residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
from vllm import _custom_ops as ops
if residual is not None: if residual is not None:
ops.fused_add_rms_norm( ops.fused_add_rms_norm(
x, x,
......
...@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union ...@@ -27,7 +27,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
from vllm import _custom_ops as ops from vllm.model_executor.custom_op import CustomOp
def _rotate_neox(x: torch.Tensor) -> torch.Tensor: def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
...@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: ...@@ -43,7 +43,7 @@ def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
return x.flatten(-2) return x.flatten(-2)
class RotaryEmbedding(nn.Module): class RotaryEmbedding(CustomOp):
"""Original rotary positional embedding.""" """Original rotary positional embedding."""
def __init__( def __init__(
...@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module): ...@@ -93,7 +93,7 @@ class RotaryEmbedding(nn.Module):
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
def _forward( def forward_native(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
...@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module): ...@@ -138,13 +138,15 @@ class RotaryEmbedding(nn.Module):
key = key.flatten(-2) key = key.flatten(-2)
return query, key return query, key
def forward( def forward_cuda(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
offsets: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device, self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype) dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding() # ops.rotary_embedding()/batched_rotary_embedding()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment