Commit 7151fd54 authored by zhuwenwen's avatar zhuwenwen
Browse files

Triton-fused DeepseekScalingRotaryEmbedding

parent 8da1c576
...@@ -7,7 +7,8 @@ Tests for miscellaneous utilities ...@@ -7,7 +7,8 @@ Tests for miscellaneous utilities
import pytest import pytest
import torch import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform from vllm.platforms import current_platform
...@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim, ...@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
ref_query, ref_query,
atol=1e-2, atol=1e-2,
rtol=1e-2) rtol=1e-2)
def test_deepseek_rotary_embedding():
device = torch.device("cuda:0")
current_platform.seed_everything(0)
torch.set_default_device("cuda:0")
batch_size = 10
base = 10000
num_heads = 8
max_position = 4096
is_neox_style = False
rotary_dim = 32
head_size = 64
scaling_factor = 40.0
rot = DeepseekScalingRotaryEmbedding(head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
torch.float32,
reference=False).to(device)
rot_ref = DeepseekScalingRotaryEmbedding(head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
torch.float32,
reference=True).to(device)
positions = torch.randint(0, max_position, (batch_size, ), device=device)
# query is [batch, num_heads, head_size]
# key is [batch, 1, head_size]
# cos_sin is [batch, head_size]
query = torch.randn(batch_size,
num_heads,
head_size,
dtype=torch.float32,
device=device)
key = torch.randn(batch_size,
1,
head_size,
dtype=torch.float32,
device=device)
ref_query, ref_key = rot_ref.forward(positions, query, key)
out_query, out_key = rot.forward(positions, query, key)
torch.testing.assert_close(out_key.cpu(),
ref_key.cpu(),
atol=1e-4,
rtol=1e-4)
torch.testing.assert_close(out_query.cpu(),
ref_query.cpu(),
atol=1e-4,
rtol=1e-4)
...@@ -30,6 +30,9 @@ from typing import Any, Optional, Union ...@@ -30,6 +30,9 @@ from typing import Any, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import triton
import triton.language as tl
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
...@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: ...@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0 return 0.1 * mscale * math.log(scale) + 1.0
@triton.jit
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
stride2: int, stride_cs: int,
dim1: int, dim2: int, dim3: int,
BLOCK_SIZE: tl.constexpr):
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)
offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
mask = offsets_cs < dim3
mask2 = offsets_q < dim3 * 2
v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
v_cos2 = tl.interleave(v_cos, v_cos)
v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
v_sin2 = tl.interleave(v_sin, v_sin)
x12 = tl.load(q + offsets, mask=mask2)
x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl.debug_barrier()
x12_ = tl.ravel(tl.join(-x2, x1))
x12 = x12 * v_cos2 + x12_ * v_sin2
tl.store(q + offsets, x12, mask=mask2)
class DeepseekScalingRotaryEmbedding(RotaryEmbedding): class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method. """RotaryEmbedding extended with YaRN method.
...@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1, beta_slow: int = 1,
mscale: float = 1, mscale: float = 1,
mscale_all_dim: float = 0, mscale_all_dim: float = 0,
reference: bool = False,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor self.attn_factor = attn_factor
self.beta_fast = beta_fast self.beta_fast = beta_fast
self.beta_slow = beta_slow self.beta_slow = beta_slow
self.reference = reference
# Get n-d magnitude scaling corrected for interpolation. # Get n-d magnitude scaling corrected for interpolation.
self.mscale = float( self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) / yarn_get_mscale(self.scaling_factor, float(mscale)) /
...@@ -874,30 +907,59 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -874,30 +907,59 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
) -> tuple[torch.Tensor, Optional[torch.Tensor]]: ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
assert key is not None assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
if self.cos_sin_cache.device != positions.device: if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device) positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets) cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions] if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1) if query.device.type == 'cuda' and not self.is_neox_style \
if self.is_neox_style: and not self.reference:
# NOTE(woosuk): Here we assume that the positions tensor has the assert len(query.shape) == 3
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2) def call(q):
sin = sin.repeat(1, 1, 2).unsqueeze(-2) BLOCK_SIZE = 64
grid = (
q.shape[-3],
q.shape[-2],
triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
)
deepseek_scaling_rotary_emb_kernel_gptj[grid](
cos_sin,
q,
stride1=q.stride()[-3],
stride2=q.stride()[-2],
stride_cs=cos_sin.stride()[-2],
dim1=q.shape[0],
dim2=q.shape[1],
dim3=self.rotary_dim // 2,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
call(query)
call(key)
return query, key
else: else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) query_rot = query[..., :self.rotary_dim]
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size: if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1) query = torch.cat((query_rot, query_pass), dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment