"vscode:/vscode.git/clone" did not exist on "c0f5fae601cf2649dec3cb06ad80008ced7a46ea"
Commit 7151fd54 authored by zhuwenwen's avatar zhuwenwen
Browse files

Triton-fused DeepseekScalingRotaryEmbedding

parent 8da1c576
......@@ -7,7 +7,8 @@ Tests for miscellaneous utilities
import pytest
import torch
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.platforms import current_platform
......@@ -66,3 +67,61 @@ def test_rotary_embedding_opcheck(max_position, is_neox_style, rotary_dim,
ref_query,
atol=1e-2,
rtol=1e-2)
def test_deepseek_rotary_embedding():
device = torch.device("cuda:0")
current_platform.seed_everything(0)
torch.set_default_device("cuda:0")
batch_size = 10
base = 10000
num_heads = 8
max_position = 4096
is_neox_style = False
rotary_dim = 32
head_size = 64
scaling_factor = 40.0
rot = DeepseekScalingRotaryEmbedding(head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
torch.float32,
reference=False).to(device)
rot_ref = DeepseekScalingRotaryEmbedding(head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
torch.float32,
reference=True).to(device)
positions = torch.randint(0, max_position, (batch_size, ), device=device)
# query is [batch, num_heads, head_size]
# key is [batch, 1, head_size]
# cos_sin is [batch, head_size]
query = torch.randn(batch_size,
num_heads,
head_size,
dtype=torch.float32,
device=device)
key = torch.randn(batch_size,
1,
head_size,
dtype=torch.float32,
device=device)
ref_query, ref_key = rot_ref.forward(positions, query, key)
out_query, out_key = rot.forward(positions, query, key)
torch.testing.assert_close(out_key.cpu(),
ref_key.cpu(),
atol=1e-4,
rtol=1e-4)
torch.testing.assert_close(out_query.cpu(),
ref_query.cpu(),
atol=1e-4,
rtol=1e-4)
......@@ -30,6 +30,9 @@ from typing import Any, Optional, Union
import numpy as np
import torch
import torch.nn as nn
import triton
import triton.language as tl
from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
......@@ -796,6 +799,34 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0
@triton.jit
def deepseek_scaling_rotary_emb_kernel_gptj(cos_sin, q, stride1: int,
stride2: int, stride_cs: int,
dim1: int, dim2: int, dim3: int,
BLOCK_SIZE: tl.constexpr):
pid0 = tl.program_id(0)
pid1 = tl.program_id(1)
pid2 = tl.program_id(2)
offsets_cs = tl.arange(0, BLOCK_SIZE) + pid2 * BLOCK_SIZE
offsets_q = tl.arange(0, BLOCK_SIZE * 2) + pid2 * BLOCK_SIZE * 2
offsets = pid0 * stride1 + pid1 * stride2 + offsets_q
mask = offsets_cs < dim3
mask2 = offsets_q < dim3 * 2
v_cos = tl.load(cos_sin + pid0 * stride_cs + offsets_cs, mask=mask)
v_cos2 = tl.interleave(v_cos, v_cos)
v_sin = tl.load(cos_sin + pid0 * stride_cs + dim3 + offsets_cs, mask=mask)
v_sin2 = tl.interleave(v_sin, v_sin)
x12 = tl.load(q + offsets, mask=mask2)
x1, x2 = tl.split(x12.reshape([BLOCK_SIZE, 2]))
# we are both reading and writing 'q'; make sure all warps are in sync
tl.debug_barrier()
x12_ = tl.ravel(tl.join(-x2, x1))
x12 = x12 * v_cos2 + x12_ * v_sin2
tl.store(q + offsets, x12, mask=mask2)
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
"""RotaryEmbedding extended with YaRN method.
......@@ -818,12 +849,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
reference: bool = False,
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow
self.reference = reference
# Get n-d magnitude scaling corrected for interpolation.
self.mscale = float(
yarn_get_mscale(self.scaling_factor, float(mscale)) /
......@@ -874,30 +907,59 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
if query.device.type == 'cuda' and not self.is_neox_style \
and not self.reference:
assert len(query.shape) == 3
def call(q):
BLOCK_SIZE = 64
grid = (
q.shape[-3],
q.shape[-2],
triton.cdiv(self.rotary_dim // 2, BLOCK_SIZE),
)
deepseek_scaling_rotary_emb_kernel_gptj[grid](
cos_sin,
q,
stride1=q.stride()[-3],
stride2=q.stride()[-2],
stride_cs=cos_sin.stride()[-2],
dim1=q.shape[0],
dim2=q.shape[1],
dim3=self.rotary_dim // 2,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
call(query)
call(key)
return query, key
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj
query_rot = query_rot * cos + rotate_fn(query_rot) * sin
key_rot = key_rot * cos + rotate_fn(key_rot) * sin
if self.rotary_dim < self.head_size:
query = torch.cat((query_rot, query_pass), dim=-1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment