Unverified Commit 9b294976 authored by Woosuk Kwon, committed by GitHub

Add PyTorch-native implementation of custom layers (#1898)

parent 5313c2cb
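
This commit gives each custom-kernel layer a PyTorch-native `_forward` method and rewrites the kernel tests to compare the CUDA path against it, instead of against hand-written reference code inside the tests. As a rough sketch of the new pattern (assuming a CUDA device and a vLLM build with the compiled `vllm._C` extension), the activation test now boils down to:

import torch
from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(7, 2 * 512, dtype=torch.half, device="cuda")
layer = SiluAndMul()
out = layer(x)           # forward(): custom CUDA kernel (vllm._C.ops.silu_and_mul)
ref = layer._forward(x)  # PyTorch-native implementation added in this commit
assert torch.allclose(out, ref, atol=1e-5, rtol=1e-5)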
tests/kernels/test_activation.py

 import pytest
 import torch
-import torch.nn.functional as F
-from transformers.activations import get_activation

-from vllm._C import ops
+from vllm.model_executor.layers.activation import FastGELU, NewGELU, SiluAndMul

 DTYPES = [torch.half, torch.bfloat16, torch.float]
 NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing
@@ -11,11 +9,6 @@ D = [512, 4096, 5120, 13824]  # Arbitrary values for testing
 SEEDS = [0]

-
-def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
-    x1, x2 = x.chunk(chunks=2, dim=1)
-    return F.silu(x1) * x2
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("d", D)
 @pytest.mark.parametrize("dtype", DTYPES)
@@ -30,9 +23,9 @@ def test_silu_and_mul(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.silu_and_mul(out, x)
-    ref_out = ref_silu_and_mul(x)
+    layer = SiluAndMul()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

@@ -50,9 +43,9 @@ def test_gelu_new(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_new(out, x)
-    ref_out = get_activation("gelu_new")(x)
+    layer = NewGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

@@ -69,7 +62,7 @@ def test_gelu_fast(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
-    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
-    ops.gelu_fast(out, x)
-    ref_out = get_activation("gelu_fast")(x)
+    layer = FastGELU()
+    out = layer(x)
+    ref_out = layer._forward(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
tests/kernels/test_layernorm.py

 import pytest
 import torch
-import torch.nn as nn

-from vllm._C import ops
+from vllm.model_executor.layers.layernorm import RMSNorm

 DTYPES = [torch.half, torch.bfloat16, torch.float]
+HIDDEN_SIZES = [67, 768, 2048, 5120, 8192]  # Arbitrary values for testing
 NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing
-HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing
+ADD_RESIDUAL = [False, True]
 SEEDS = [0]

-
-class RefRMSNorm(nn.Module):
-
-    def __init__(self, hidden_size, eps=1e-6):
-        super().__init__()
-        weight = torch.empty(hidden_size)
-        weight.normal_(mean=1.0, std=0.1)
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states):
-        input_dtype = hidden_states.dtype
-        hidden_states = hidden_states.to(torch.float32)
-        variance = hidden_states.pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance +
-                                                    self.variance_epsilon)
-        return self.weight * hidden_states.to(input_dtype)
-
-
 @pytest.mark.parametrize("num_tokens", NUM_TOKENS)
 @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("add_residual", ADD_RESIDUAL)
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @torch.inference_mode()
 def test_rms_norm(
     num_tokens: int,
     hidden_size: int,
+    add_residual: bool,
     dtype: torch.dtype,
     seed: int,
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    scale = float(hidden_size**-0.5)
-    x = torch.empty(num_tokens, hidden_size, dtype=dtype, device="cuda")
-    x.uniform_(-scale, scale)
-    ref = RefRMSNorm(hidden_size).to(dtype).cuda()
-
-    out = torch.empty_like(x)
-    ops.rms_norm(
-        out,
-        x,
-        ref.weight.data,
-        ref.variance_epsilon,
-    )
-    ref_out = ref(x)
-    assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-5)
+    layer = RMSNorm(hidden_size).to(dtype).cuda()
+    layer.weight.data.normal_(mean=1.0, std=0.1)
+    scale = 1 / (2 * hidden_size)
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda")
+    x *= scale
+    residual = torch.randn_like(x) * scale if add_residual else None
+
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_out = layer._forward(x, residual)
+    out = layer(x, residual)
+    # NOTE(woosuk): LayerNorm operators (including RMS) typically have larger
+    # numerical errors than other operators because they involve reductions.
+    # Therefore, we use a larger tolerance.
+    if add_residual:
+        assert torch.allclose(out[0], ref_out[0], atol=1e-2, rtol=1e-2)
+        assert torch.allclose(out[1], ref_out[1], atol=1e-2, rtol=1e-2)
+    else:
+        assert torch.allclose(out, ref_out, atol=1e-2, rtol=1e-2)
tests/kernels/test_pos_encoding.py

-from typing import Optional, Tuple
+from typing import Optional

 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F

-from vllm._C import ops
+from vllm.model_executor.layers.rotary_embedding import get_rope

 IS_NEOX_STYLE = [True, False]
 DTYPES = [torch.half, torch.bfloat16, torch.float]
 HEAD_SIZES = [64, 80, 96, 112, 128, 256]
 ROTARY_DIMS = [None, 32]  # None means rotary dim == head size
-NUM_HEADS = [7, 12, 40, 52]  # Arbitrary values for testing
-NUM_TOKENS = [11, 83, 2048]  # Arbitrary values for testing
+NUM_HEADS = [7, 17]  # Arbitrary values for testing
+BATCH_SIZES = [1, 5]  # Arbitrary values for testing
+SEQ_LENS = [11, 8192]  # Arbitrary values for testing
 SEEDS = [0]

-
-def rotate_neox(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., :x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
-    x1 = x[..., ::2]
-    x2 = x[..., 1::2]
-    x = torch.stack((-x2, x1), dim=-1)
-    return x.flatten(-2)
-
-
-def apply_rope(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-    is_neox_style: bool,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    rotate_fn = rotate_neox if is_neox_style else rotate_gptj
-    q_embed = (q * cos) + (rotate_fn(q) * sin)
-    k_embed = (k * cos) + (rotate_fn(k) * sin)
-    return q_embed, k_embed
-
-
-class RefRotaryEmbedding(nn.Module):
-    """Reference implementation of rotary embedding."""
-
-    def __init__(
-        self,
-        dim: int,
-        is_neox_style: bool,
-        max_position_embeddings: int = 8192,
-        base: int = 10000,
-    ) -> None:
-        super().__init__()
-        self.rotary_dim = dim
-        self.is_neox_style = is_neox_style
-        self.max_position_embeddings = max_position_embeddings
-
-        # Create cos and sin embeddings.
-        inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim))
-        t = torch.arange(max_position_embeddings).float()
-        freqs = torch.einsum("i,j->ij", t, inv_freq.float())
-        if is_neox_style:
-            emb = torch.cat((freqs, freqs), dim=-1)
-        else:
-            emb = torch.repeat_interleave(freqs, 2, -1)
-        cos = emb.cos().to(dtype=inv_freq.dtype)
-        sin = emb.sin().to(dtype=inv_freq.dtype)
-        self.register_buffer("cos_cached", cos, persistent=False)
-        self.register_buffer("sin_cached", sin, persistent=False)
-
-    def forward(
-        self,
-        positions: torch.Tensor,  # [num_tokens]
-        query: torch.Tensor,  # [num_tokens, num_heads, head_size]
-        key: torch.Tensor,  # [num_tokens, num_heads, head_size]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        query_rot = query[..., :self.rotary_dim]
-        query_pass = query[..., self.rotary_dim:]
-        key_rot = key[..., :self.rotary_dim]
-        key_pass = key[..., self.rotary_dim:]
-
-        query_rot = query_rot.transpose(0, 1)
-        key_rot = key_rot.transpose(0, 1)
-        cos = F.embedding(positions, self.cos_cached)
-        sin = F.embedding(positions, self.sin_cached)
-        query_rot, key_rot = apply_rope(query_rot, key_rot, cos, sin,
-                                        self.is_neox_style)
-        query_rot = query_rot.transpose(0, 1).contiguous()
-        key_rot = key_rot.transpose(0, 1).contiguous()
-
-        query = torch.cat((query_rot, query_pass), dim=-1)
-        key = torch.cat((key_rot, key_pass), dim=-1)
-
-        # Output query/key shape: [num_tokens, num_tokens, head_size]
-        return query, key
-
-
 @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE)
-@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("batch_size", BATCH_SIZES)
+@pytest.mark.parametrize("seq_len", SEQ_LENS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS)
@@ -108,7 +26,8 @@ class RefRotaryEmbedding(nn.Module):
 @torch.inference_mode()
 def test_rotary_embedding(
     is_neox_style: bool,
-    num_tokens: int,
+    batch_size: int,
+    seq_len: int,
     num_heads: int,
     head_size: int,
     rotary_dim: Optional[int],
@@ -122,53 +41,25 @@ def test_rotary_embedding(
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    positions = torch.randint(0, max_position, (num_tokens, ), device="cuda")
-    query = torch.randn(num_tokens,
+    if rotary_dim is None:
+        rotary_dim = head_size
+    rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style)
+    rope = rope.to(dtype).cuda()
+
+    positions = torch.randint(0,
+                              max_position, (batch_size, seq_len),
+                              device="cuda")
+    query = torch.randn(batch_size,
+                        seq_len,
                         num_heads * head_size,
                         dtype=dtype,
                         device="cuda")
-    key = torch.randn(num_tokens,
-                      num_heads * head_size,
-                      dtype=dtype,
-                      device="cuda")
-
-    # Create the rotary embedding.
-    inv_freq = 1.0 / (base**(
-        torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
-    t = torch.arange(max_position).float()
-    freqs = torch.einsum("i,j -> ij", t, inv_freq)
-    cos = freqs.cos()
-    sin = freqs.sin()
-    cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")
-
-    # Run the kernel. The kernel is in-place, so we need to clone the inputs.
-    out_query = query.clone()
-    out_key = key.clone()
-    ops.rotary_embedding(
-        positions,
-        out_query,
-        out_key,
-        head_size,
-        cos_sin_cache,
-        is_neox_style,
-    )
+    key = torch.randn_like(query)

-    # Run the reference implementation.
-    ref_rotary_embedding = RefRotaryEmbedding(
-        dim=rotary_dim,
-        is_neox_style=is_neox_style,
-        max_position_embeddings=max_position,
-        base=base,
-    ).to(dtype=dtype, device="cuda")
-    ref_query, ref_key = ref_rotary_embedding(
-        positions,
-        query.view(num_tokens, num_heads, head_size),
-        key.view(num_tokens, num_heads, head_size),
-    )
-    ref_query = ref_query.view(num_tokens, num_heads * head_size)
-    ref_key = ref_key.view(num_tokens, num_heads * head_size)
+    # NOTE(woosuk): The reference implementation should be executed first
+    # because the custom kernel is in-place.
+    ref_query, ref_key = rope._forward(positions, query, key)
+    out_query, out_key = rope.forward(positions, query, key)

     # Compare the results.
     assert torch.allclose(out_query, ref_query, atol=1e-5, rtol=1e-5)
     assert torch.allclose(out_key, ref_key, atol=1e-5, rtol=1e-5)
"""Custom activation functions.""" """Custom activation functions."""
import math
from typing import Optional from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from vllm._C import ops from vllm._C import ops
from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization import QuantizationConfig
...@@ -22,6 +24,11 @@ class SiluAndMul(nn.Module): ...@@ -22,6 +24,11 @@ class SiluAndMul(nn.Module):
return: (batch_size, seq_len, d) or (num_tokens, d) return: (batch_size, seq_len, d) or (num_tokens, d)
""" """
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
output_shape = (x.shape[:-1] + (d, )) output_shape = (x.shape[:-1] + (d, ))
...@@ -32,6 +39,12 @@ class SiluAndMul(nn.Module): ...@@ -32,6 +39,12 @@ class SiluAndMul(nn.Module):
class NewGELU(nn.Module): class NewGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
c = math.sqrt(2.0 / math.pi)
return 0.5 * x * (1.0 + torch.tanh(c *
(x + 0.044715 * torch.pow(x, 3.0))))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_new(out, x) ops.gelu_new(out, x)
...@@ -40,6 +53,11 @@ class NewGELU(nn.Module): ...@@ -40,6 +53,11 @@ class NewGELU(nn.Module):
class FastGELU(nn.Module): class FastGELU(nn.Module):
def _forward(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
(1.0 + 0.044715 * x * x)))
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
out = torch.empty_like(x) out = torch.empty_like(x)
ops.gelu_fast(out, x) ops.gelu_fast(out, x)
......
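
Both `NewGELU._forward` and `FastGELU._forward` are algebraic rearrangements of the same tanh approximation of GELU (0.7978845608 ≈ sqrt(2/π)), so they should also track PyTorch's built-in approximation closely. A quick sanity check, illustrative only and assuming the module imports cleanly without a GPU:

import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import FastGELU, NewGELU

x = torch.randn(1024, dtype=torch.float)
ref = F.gelu(x, approximate="tanh")  # tanh approximation (PyTorch >= 1.12)
assert torch.allclose(NewGELU()._forward(x), ref, atol=1e-5)
assert torch.allclose(FastGELU()._forward(x), ref, atol=1e-5)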
vllm/model_executor/layers/layernorm.py

@@ -23,6 +23,26 @@ class RMSNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps

+    def _forward(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        """PyTorch-native implementation equivalent to forward()."""
+        orig_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(orig_dtype)
+
+        variance = x.pow(2).mean(dim=-1, keepdim=True)
+        x = x * torch.rsqrt(variance + self.variance_epsilon)
+        x = x.to(orig_dtype) * self.weight
+        if residual is None:
+            return x
+        else:
+            return x, residual
+
     def forward(
         self,
         x: torch.Tensor,
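
The new `RMSNorm._forward` mirrors the fused kernel's optional residual path: without a residual it returns the normalized tensor; with one it returns an `(output, updated_residual)` tuple, which is what the updated test indexes into. A minimal usage sketch, assuming a CUDA device and an installed vLLM build:

import torch
from vllm.model_executor.layers.layernorm import RMSNorm

layer = RMSNorm(4096).to(torch.half).cuda()
x = torch.randn(8, 4096, dtype=torch.half, device="cuda")
residual = torch.randn_like(x)

out = layer._forward(x)                           # no residual: a single tensor
out, new_residual = layer._forward(x, residual)   # fused residual-add variant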
vllm/model_executor/layers/rotary_embedding.py

@@ -30,6 +30,19 @@ import torch.nn as nn
 from vllm._C import ops

+def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., :x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def _rotate_gptj(x: torch.Tensor) -> torch.Tensor:
+    x1 = x[..., ::2]
+    x2 = x[..., 1::2]
+    x = torch.stack((-x2, x1), dim=-1)
+    return x.flatten(-2)
+
+
 class RotaryEmbedding(nn.Module):
     """Original rotary positional embedding."""
@@ -81,6 +94,47 @@ class RotaryEmbedding(nn.Module):
         cache = torch.cat((cos, sin), dim=-1)
         return cache

+    def _forward(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """PyTorch-native implementation equivalent to forward()."""
+        query = query.view(*query.shape[:-1], -1, self.head_size)
+        key = key.view(*key.shape[:-1], -1, self.head_size)
+
+        query_rot = query[..., :self.rotary_dim]
+        key_rot = key[..., :self.rotary_dim]
+        if self.rotary_dim < self.head_size:
+            query_pass = query[..., self.rotary_dim:]
+            key_pass = key[..., self.rotary_dim:]
+
+        cos_sin = self.cos_sin_cache[positions]
+        cos, sin = cos_sin.chunk(2, dim=-1)
+        if self.is_neox_style:
+            # NOTE(woosuk): Here we assume that the positions tensor has the
+            # shape [batch_size, seq_len].
+            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
+            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
+        else:
+            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
+            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
+
+        rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
+        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
+        key_rot = key_rot * cos + rotate_fn(key_rot) * sin
+
+        if self.rotary_dim < self.head_size:
+            query = torch.cat((query_rot, query_pass), dim=-1)
+            key = torch.cat((key_rot, key_pass), dim=-1)
+        else:
+            query = query_rot
+            key = key_rot
+        query = query.flatten(-2)
+        key = key.flatten(-2)
+        return query, key
+
     def forward(
         self,
         positions: torch.Tensor,
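
Putting the rotary-embedding pieces together, the updated test builds the layer via `get_rope` and checks the in-place CUDA kernel against the new `_forward`. A condensed sketch with illustrative values for `max_position` and `base` (the test sweeps its parameters via pytest parametrization):

import torch
from vllm.model_executor.layers.rotary_embedding import get_rope

head_size, num_heads, batch_size, seq_len = 128, 8, 2, 16
max_position, base = 8192, 10000  # illustrative values
rope = get_rope(head_size, head_size, max_position, base, True).to(torch.half).cuda()

positions = torch.randint(0, max_position, (batch_size, seq_len), device="cuda")
query = torch.randn(batch_size, seq_len, num_heads * head_size,
                    dtype=torch.half, device="cuda")
key = torch.randn_like(query)

# The PyTorch-native path must run first: the CUDA kernel rotates query/key
# in place, so calling forward() first would corrupt the reference inputs.
ref_query, ref_key = rope._forward(positions, query, key)
out_query, out_key = rope.forward(positions, query, key)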