"tests/vscode:/vscode.git/clone" did not exist on "5412a3341fa5d0211629ee87899015f98a62e0cc"
Unverified commit e1eae1fd authored by Ke Bao, committed by GitHub

Support MLA for DeepSeek-V2 with Triton - step 1 (#905)

parent f4d9953d
File mode changed from 100644 to 100755
@@ -57,6 +57,8 @@ def _fwd_kernel(
     stride_buf_vh,
     stride_req_to_tokens_b,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
+    BLOCK_DV: tl.constexpr,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
@@ -75,8 +77,10 @@ def _fwd_kernel(
     cur_batch_req_idx = tl.load(B_req_idx + cur_seq)

     offs_d = tl.arange(0, BLOCK_DMODEL)
+    offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
     mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend

     offs_q = (
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs
@@ -85,10 +89,20 @@ def _fwd_kernel(
     )
     q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        offs_qpe = (
+            (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
+            * stride_qbs
+            + cur_head * stride_qh
+            + offs_dpe[None, :]
+        )
+        qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
+
     # stage1: compute scores with prefix
     offs_n = tl.arange(0, BLOCK_N)

-    acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
+    acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
     deno = tl.zeros([BLOCK_M], dtype=tl.float32)
     e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
@@ -110,6 +124,18 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                offs_kv_loc[None, :] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -125,7 +151,7 @@ def _fwd_kernel(
         offs_buf_v = (
             offs_kv_loc[:, None] * stride_buf_vbs
             + cur_kv_head * stride_buf_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -150,6 +176,21 @@ def _fwd_kernel(
         qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
         qk += tl.dot(q, k)
+        if BLOCK_DPE > 0:
+            offs_kpe = (
+                (cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
+                * stride_kbs
+                + cur_kv_head * stride_kh
+                + offs_dpe[:, None]
+            )
+            kpe = tl.load(
+                K_Extend + offs_kpe,
+                mask=mask_n[None, :],
+                other=0.0,
+            )
+            qk += tl.dot(qpe, kpe)
         qk *= sm_scale

         if logit_cap > 0:
@@ -169,7 +210,7 @@ def _fwd_kernel(
         offs_v = (
             (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
             + cur_kv_head * stride_vh
-            + offs_d[None, :]
+            + offs_dv[None, :]
         )
         v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
         p = p.to(v.dtype)
@@ -181,7 +222,7 @@ def _fwd_kernel(
         (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs
         + cur_head * stride_oh
-        + offs_d[None, :]
+        + offs_dv[None, :]
     )
     tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
@@ -217,8 +258,17 @@ def extend_attention_fwd(
         o_extend.shape[-1],
     )
-    assert Lq == Lk and Lk == Lv and Lv == Lo
-    assert Lq in {16, 32, 64, 128, 256}
+    assert Lq == Lk and Lv == Lo
+    assert Lq in {16, 32, 64, 128, 256, 576}
+    assert Lv in {16, 32, 64, 128, 256, 512}
+
+    if Lq == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lq
+        BLOCK_DPE = 0
+    BLOCK_DV = Lv

     if CUDA_CAPABILITY[0] >= 8:
         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
@@ -260,7 +310,9 @@ def extend_attention_fwd(
         v_buffer.stride(0),
         v_buffer.stride(1),
         req_to_tokens.stride(0),
-        BLOCK_DMODEL=Lq,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
+        BLOCK_DV=BLOCK_DV,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
         num_warps=num_warps,
......
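The BLOCK_DPE / BLOCK_DV constants added above exist because MLA gives queries and keys a 576-wide head (a 512-dim latent "nope" part plus a 64-dim rotary part) while values and outputs stay 512-wide; the extra tl.dot(qpe, kpe) simply adds the rotary segment's contribution to the same scores. A minimal PyTorch sketch of that split (sizes and tensor names are illustrative, not taken from the kernel):

import torch

# 512 + 64 == 576 mirrors the Lq == 576 dispatch in extend_attention_fwd above.
BLOCK_DMODEL, BLOCK_DPE = 512, 64
num_q, num_k = 8, 16  # toy token counts

q = torch.randn(num_q, BLOCK_DMODEL + BLOCK_DPE)
k = torch.randn(num_k, BLOCK_DMODEL + BLOCK_DPE)

# Full-width attention scores.
qk_full = q @ k.T

# Scores computed the way the kernel does it: latent part plus rotary part.
q_nope, q_pe = q[:, :BLOCK_DMODEL], q[:, BLOCK_DMODEL:]
k_nope, k_pe = k[:, :BLOCK_DMODEL], k[:, BLOCK_DMODEL:]
qk_split = q_nope @ k_nope.T + q_pe @ k_pe.T

assert torch.allclose(qk_full, qk_split, atol=1e-3)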
@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
         num_kv_heads: int,
         layer_id: int,
         logit_cap: int = -1,
+        v_head_dim: int = -1,
     ):
         super().__init__()
         self.tp_q_head_num = num_heads
         self.tp_k_head_num = num_kv_heads
         self.tp_v_head_num = num_kv_heads
         self.head_dim = head_dim
+        self.qk_head_dim = head_dim
+        self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id

-        if not global_server_args_dict.get("disable_flashinfer", False):
+        if (
+            not global_server_args_dict.get("disable_flashinfer", False)
+            and self.qk_head_dim == self.v_head_dim
+        ):
             self.extend_forward = self.extend_forward_flashinfer
             self.decode_forward = self.decode_forward_flashinfer
         else:
@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
         self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0

     def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
             o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         extend_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             k.contiguous(),
             v.contiguous(),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
             input_metadata.req_to_token_pool.req_to_token,
@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
         return o

     def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
+        if self.qk_head_dim != self.v_head_dim:
+            o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
+        else:
             o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)

         token_attention_fwd(
-            q.view(-1, self.tp_q_head_num, self.head_dim),
+            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
             input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
             input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
-            o.view(-1, self.tp_q_head_num, self.head_dim),
+            o.view(-1, self.tp_q_head_num, self.v_head_dim),
             input_metadata.req_to_token_pool.req_to_token,
             input_metadata.req_pool_indices,
             input_metadata.triton_start_loc,
@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
         return o.view(-1, self.tp_q_head_num * self.head_dim)

     def forward(self, q, k, v, input_metadata: InputMetadata):
-        k = k.view(-1, self.tp_k_head_num, self.head_dim)
-        v = v.view(-1, self.tp_v_head_num, self.head_dim)
+        k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+        v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

         if input_metadata.forward_mode == ForwardMode.EXTEND:
             return self.extend_forward(q, k, v, input_metadata)
......
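RadixAttention now tracks qk_head_dim and v_head_dim separately; in the MLA case they differ (576 for q/k versus 512 for v and the output), so the output buffer can no longer be torch.empty_like(q). A short sketch of the shape bookkeeping the Triton paths above rely on (sizes are illustrative):

import torch

num_tokens, tp_q_head_num = 4, 2
qk_head_dim, v_head_dim = 512 + 64, 512  # kv_lora_rank + rope dims vs. kv_lora_rank

q = torch.randn(num_tokens, tp_q_head_num * qk_head_dim)

# torch.empty_like(q) would be 576 wide per head; the output follows v_head_dim.
o = q.new_empty((q.shape[0], tp_q_head_num * v_head_dim))

assert q.view(-1, tp_q_head_num, qk_head_dim).shape == (num_tokens, tp_q_head_num, qk_head_dim)
assert o.view(-1, tp_q_head_num, v_head_dim).shape == (num_tokens, tp_q_head_num, v_head_dim)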
@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
     att_stride_h,
     kv_group_num: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
+    BLOCK_DPE: tl.constexpr,
     BLOCK_N: tl.constexpr,
     logit_cap: tl.constexpr,
 ):
@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

+    if BLOCK_DPE > 0:
+        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
+        off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
+
     offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)

     block_stard_index = start_n * BLOCK_N
@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
             other=0.0,
         ).to(REDUCE_TRITON_TYPE)
         att_value = tl.sum(q[None, :] * k, 1)
+        if BLOCK_DPE > 0:
+            qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
+            offs_buf_kpe = (
+                k_loc[:, None] * stride_buf_kbs
+                + cur_kv_head * stride_buf_kh
+                + offs_dpe[None, :]
+            )
+            kpe = tl.load(
+                K_Buffer + offs_buf_kpe,
+                mask=offs_n_new[:, None] < cur_batch_end_index,
+                other=0.0,
+            ).to(REDUCE_TRITON_TYPE)
+            att_value += tl.sum(qpe[None, :] * kpe, 1)
         att_value *= sm_scale

         if logit_cap > 0:
@@ -192,7 +210,14 @@ def _token_att_m_fwd(
     # shape constraints
     Lq, Lk = q.shape[-1], k_buffer.shape[-1]
     assert Lq == Lk
-    assert Lk in {16, 32, 64, 128, 256}
+    assert Lk in {16, 32, 64, 128, 256, 576}
+
+    if Lk == 576:
+        BLOCK_DMODEL = 512
+        BLOCK_DPE = 64
+    else:
+        BLOCK_DMODEL = Lk
+        BLOCK_DPE = 0

     batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -220,7 +245,8 @@ def _token_att_m_fwd(
         k_buffer.stride(1),
         att_out.stride(0),
         kv_group_num=kv_group_num,
-        BLOCK_DMODEL=Lk,
+        BLOCK_DMODEL=BLOCK_DMODEL,
+        BLOCK_DPE=BLOCK_DPE,
         BLOCK_N=BLOCK,
         logit_cap=logit_cap,
         num_warps=num_warps,
......
@@ -29,7 +29,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained import RegexGuide
 from sglang.srt.constrained.jump_forward import JumpForwardMap
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.mem_cache.radix_cache import RadixCache

 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -39,6 +39,7 @@ global_server_args_dict = {
     "disable_flashinfer": False,
     "disable_flashinfer_sampling": False,
     "attention_reduce_in_fp32": False,
+    "enable_mla": False,
 }
@@ -289,7 +290,7 @@ class Batch:
     # Request, memory pool, and cache
     reqs: List[Req]
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     tree_cache: RadixCache

     # Batched arguments to model runner
@@ -780,7 +781,7 @@ class InputMetadata:
     seq_lens: torch.Tensor
     positions: torch.Tensor
     req_to_token_pool: ReqToTokenPool
-    token_to_kv_pool: TokenToKVPool
+    token_to_kv_pool: BaseTokenToKVPool
     # For extend
     extend_seq_lens: torch.Tensor
......
@@ -57,32 +57,18 @@ class ReqToTokenPool:
         self.can_use_mem_size = len(self.mem_state)


-class TokenToKVPool:
+class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

     def __init__(
         self,
         size: int,
-        dtype: torch.dtype,
-        head_num: int,
-        head_dim: int,
-        layer_num: int,
     ):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")

-        # [size, head_num, head_dim] for each layer
-        self.k_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-        self.v_buffer = [
-            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
-            for _ in range(layer_num)
-        ]
-
         # Prefetch buffer
         self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
         self.prefetch_chunk_size = 512
@@ -90,15 +76,6 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id]
-
-    def get_value_buffer(self, layer_id: int):
-        return self.v_buffer[layer_id]
-
-    def get_kv_buffer(self, layer_id: int):
-        return self.k_buffer[layer_id], self.v_buffer[layer_id]
-
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
@@ -139,3 +116,67 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state[0] = False
+
+
+class MHATokenToKVPool(BaseTokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+
+class MLATokenToKVPool(BaseTokenToKVPool):
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        kv_lora_rank: int,
+        qk_rope_head_dim: int,
+        layer_num: int,
+    ):
+        super().__init__(size)
+
+        self.kv_lora_rank = kv_lora_rank
+        self.kv_buffer = [
+            torch.empty(
+                (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
+                dtype=dtype,
+                device="cuda",
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
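MLATokenToKVPool keeps one joint buffer per layer, a single "head" of width kv_lora_rank + qk_rope_head_dim: the key view is the whole buffer and the value view is a zero-copy slice of the same storage, so no separate V cache is allocated. A CPU-only sketch of that layout (device and buffer sizes changed for illustration; the real pool lives on "cuda"):

import torch

size, kv_lora_rank, qk_rope_head_dim, layer_num = 16, 512, 64, 2
kv_buffer = [
    torch.empty((size + 1, 1, kv_lora_rank + qk_rope_head_dim), dtype=torch.float16)
    for _ in range(layer_num)
]

key_view = kv_buffer[0]                        # what get_key_buffer(0) returns
value_view = kv_buffer[0][..., :kv_lora_rank]  # what get_value_buffer(0) returns

# Both views alias the same allocation: 576 cached values per token in total,
# instead of separate K and V tensors per head.
assert value_view.data_ptr() == key_view.data_ptr()
assert key_view.shape[-1] == 576 and value_view.shape[-1] == 512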
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+from enum import IntEnum, auto
 from typing import Optional

 from transformers import PretrainedConfig
@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
 from sglang.srt.hf_transformers_utils import get_config, get_context_length


+class AttentionArch(IntEnum):
+    MLA = auto()
+    MHA = auto()
+
+
 class ModelConfig:
     def __init__(
         self,
@@ -55,6 +61,11 @@ class ModelConfig:
         # FIXME: temporary special judge for deepseek v2 MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
+            self.attention_arch = AttentionArch.MLA
+            self.kv_lora_rank = self.hf_config.kv_lora_rank
+            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
+        else:
+            self.attention_arch = AttentionArch.MHA

         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
......
@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
     InputMetadata,
     global_server_args_dict,
 )
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
+from sglang.srt.mem_cache.memory_pool import (
+    MHATokenToKVPool,
+    MLATokenToKVPool,
+    ReqToTokenPool,
+)
+from sglang.srt.model_config import AttentionArch
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -86,6 +91,7 @@ class ModelRunner:
                 "disable_flashinfer": server_args.disable_flashinfer,
                 "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
                 "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+                "enable_mla": server_args.enable_mla,
             }
         )
@@ -193,11 +199,19 @@ class ModelRunner:
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
         )
-        head_dim = self.model_config.head_dim
-        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            cell_size = (
+                (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
+                * self.model_config.num_hidden_layers
+                * torch._utils._element_size(self.dtype)
+            )
+        else:
             cell_size = (
-            head_num
-            * head_dim
+                self.model_config.get_num_kv_heads(self.tp_size)
+                * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
                 * torch._utils._element_size(self.dtype)
@@ -241,7 +255,22 @@ class ModelRunner:
             max_num_reqs,
             self.model_config.context_len + 8,
         )
-        self.token_to_kv_pool = TokenToKVPool(
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and self.server_args.enable_mla
+        ):
+            self.token_to_kv_pool = MLATokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+            )
+            logger.info("using MLA Triton implementaion, flashinfer is disabled")
+            # FIXME: temporarily only Triton MLA is supported
+            self.server_args.disable_flashinfer = True
+        else:
+            self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.dtype,
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
......
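The ModelRunner change above also shrinks the per-token KV cell: with MLA a token needs one (kv_lora_rank + qk_rope_head_dim)-wide vector per layer and no separate K/V pair. A back-of-the-envelope comparison with illustrative, roughly DeepSeek-V2-like numbers (the figures below are assumptions for the sketch, not values read from the diff):

# fp16 elements and 60 hidden layers assumed for illustration.
element_size = 2
num_hidden_layers = 60

# MLA pool: one joint latent+rope vector per token per layer.
mla_cell = (512 + 64) * num_hidden_layers * element_size

# MHA-style pool for comparison: K and V per head per token per layer,
# assuming 128 KV heads with head_dim 256.
mha_cell = 128 * 256 * num_hidden_layers * 2 * element_size

print(f"MLA: {mla_cell / 1024:.1f} KiB/token, MHA: {mha_cell / 1024:.1f} KiB/token")
# The gap in bytes per token is what lets max_total_num_tokens grow when
# --enable-mla routes DeepSeek-V2 through MLATokenToKVPool.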
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.model_runner import InputMetadata
@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
         return output


+class DeepseekV2AttentionMLA(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        hidden_size: int,
+        num_heads: int,
+        qk_nope_head_dim: int,
+        qk_rope_head_dim: int,
+        v_head_dim: int,
+        q_lora_rank: int,
+        kv_lora_rank: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 8192,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        layer_id=None,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.q_lora_rank = q_lora_rank
+        self.kv_lora_rank = kv_lora_rank
+        self.num_heads = num_heads
+        tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % tp_size == 0
+        self.num_local_heads = num_heads // tp_size
+        self.scaling = self.qk_head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        if self.q_lora_rank is not None:
+            self.q_a_proj = ReplicatedLinear(
+                self.hidden_size,
+                self.q_lora_rank,
+                bias=False,
+                quant_config=quant_config,
+            )
+            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+            self.q_b_proj = ColumnParallelLinear(
+                q_lora_rank,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            self.q_proj = ColumnParallelLinear(
+                self.hidden_size,
+                self.num_heads * self.qk_head_dim,
+                bias=False,
+                quant_config=quant_config,
+            )
+
+        self.kv_a_proj_with_mqa = ReplicatedLinear(
+            self.hidden_size,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
+        self.kv_b_proj = ColumnParallelLinear(
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+            bias=False,
+            quant_config=quant_config,
+        )
+        # O projection.
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.v_head_dim,
+            self.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+        rope_scaling["type"] = "deepseek_yarn"
+        self.rotary_emb = get_rope(
+            qk_rope_head_dim,
+            rotary_dim=qk_rope_head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=False,
+        )
+
+        if rope_scaling:
+            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
+            scaling_factor = rope_scaling["factor"]
+            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
+            self.scaling = self.scaling * mscale * mscale
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.kv_lora_rank + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=1,
+            layer_id=layer_id,
+            v_head_dim=self.kv_lora_rank,
+        )
+
+        kv_b_proj = self.kv_b_proj
+        w_kc, w_vc = kv_b_proj.weight.unflatten(
+            0, (-1, qk_nope_head_dim + v_head_dim)
+        ).split([qk_nope_head_dim, v_head_dim], dim=1)
+        self.w_kc = w_kc
+        self.w_vc = w_vc
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        q_len = hidden_states.shape[0]
+        q_input = hidden_states.new_empty(
+            q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
+        )
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        q_nope_out = q_input[..., : self.kv_lora_rank]
+        torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
+
+        k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
+        k_pe = k_input[..., self.kv_lora_rank :]
+        v_input = k_input[..., : self.kv_lora_rank]
+        v_input = self.kv_a_layernorm(v_input.contiguous())
+        k_input[..., : self.kv_lora_rank] = v_input
+
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q_input[..., self.kv_lora_rank :] = q_pe
+        k_input[..., self.kv_lora_rank :] = k_pe
+
+        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+        attn_bmm_output = attn_output.new_empty(
+            q_len, self.num_local_heads, self.v_head_dim
+        )
+        torch.bmm(
+            attn_output.transpose(0, 1),
+            self.w_vc.transpose(1, 2).contiguous(),
+            out=attn_bmm_output.transpose(0, 1),
+        )
+        attn_output = attn_bmm_output.flatten(1, 2)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class DeepseekV2DecoderLayer(nn.Module):

     def __init__(
@@ -326,6 +486,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        if global_server_args_dict["enable_mla"]:
+            self.self_attn = DeepseekV2AttentionMLA(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                qk_nope_head_dim=config.qk_nope_head_dim,
+                qk_rope_head_dim=config.qk_rope_head_dim,
+                v_head_dim=config.v_head_dim,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
+                kv_lora_rank=config.kv_lora_rank,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                max_position_embeddings=max_position_embeddings,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                layer_id=layer_id,
+            )
+        else:
             self.self_attn = DeepseekV2Attention(
                 config=config,
                 hidden_size=self.hidden_size,
@@ -333,7 +513,9 @@ class DeepseekV2DecoderLayer(nn.Module):
                 qk_nope_head_dim=config.qk_nope_head_dim,
                 qk_rope_head_dim=config.qk_rope_head_dim,
                 v_head_dim=config.v_head_dim,
-            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
+                q_lora_rank=(
+                    config.q_lora_rank if hasattr(config, "q_lora_rank") else None
+                ),
                 kv_lora_rank=config.kv_lora_rank,
                 rope_theta=rope_theta,
                 rope_scaling=rope_scaling,
......
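DeepseekV2AttentionMLA never expands the cached latent vectors into per-head K/V. Instead kv_b_proj is split into w_kc and w_vc, the query's nope part is pushed into the latent space (the first torch.bmm) and the attention output is pulled back out of it (the second torch.bmm), so attention runs directly against the compressed cache. A toy-dimension check of that equivalence (RoPE, the rope dims, and the softmax scale are left out; all sizes are made up):

import torch

torch.manual_seed(0)
H, Tq, Tk = 2, 3, 5              # heads, query tokens, cached tokens
nope, v_dim, lora = 4, 6, 8      # qk_nope_head_dim, v_head_dim, kv_lora_rank (toy)

q_nope = torch.randn(H, Tq, nope)
latent = torch.randn(Tk, lora)   # what MLATokenToKVPool caches per token
w_kc = torch.randn(H, nope, lora)
w_vc = torch.randn(H, v_dim, lora)

# Naive path: expand the latent cache into per-head K and V, then attend.
k = torch.einsum("tl,hnl->htn", latent, w_kc)
v = torch.einsum("tl,hvl->htv", latent, w_vc)
p = torch.softmax(torch.einsum("hqn,htn->hqt", q_nope, k), dim=-1)
out_naive = torch.einsum("hqt,htv->hqv", p, v)

# Absorbed path: project q into the latent space, attend against the latent
# cache directly, then map the latent-space output back through w_vc.
q_latent = torch.bmm(q_nope, w_kc)                                  # [H, Tq, lora]
p2 = torch.softmax(torch.einsum("hql,tl->hqt", q_latent, latent), dim=-1)
out_latent = torch.einsum("hqt,tl->hql", p2, latent)
out_absorbed = torch.bmm(out_latent, w_vc.transpose(1, 2))          # [H, Tq, v_dim]

assert torch.allclose(out_naive, out_absorbed, atol=1e-4)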
@@ -80,6 +80,7 @@ class ServerArgs:
     disable_disk_cache: bool = False
     enable_torch_compile: bool = False
     enable_p2p_check: bool = False
+    enable_mla: bool = False
     attention_reduce_in_fp32: bool = False
     efficient_weight_load: bool = False
@@ -393,6 +394,11 @@ class ServerArgs:
             action="store_true",
             help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
         )
+        parser.add_argument(
+            "--enable-mla",
+            action="store_true",
+            help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
+        )
         parser.add_argument(
             "--attention-reduce-in-fp32",
             action="store_true",
......
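The new flag keeps MLA opt-in. A minimal sketch of enabling it programmatically: enable_mla is the field added in this diff, model_path and trust_remote_code are pre-existing ServerArgs fields, and the checkpoint name is only an example (on the command line this is the --enable-mla switch registered above):

from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # any DeepseekV2ForCausalLM checkpoint
    trust_remote_code=True,
    enable_mla=True,  # selects MLATokenToKVPool and the Triton MLA kernels
)
print(server_args.enable_mla)  # True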