Unverified Commit e1eae1fd authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Support MLA for DeepSeek-V2 with Triton - step 1 (#905)

parent f4d9953d
File mode changed from 100644 to 100755
......@@ -57,6 +57,8 @@ def _fwd_kernel(
stride_buf_vh,
stride_req_to_tokens_b,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
......@@ -75,8 +77,10 @@ def _fwd_kernel(
cur_batch_req_idx = tl.load(B_req_idx + cur_seq)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
offs_m = tl.arange(0, BLOCK_M)
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
offs_q = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
......@@ -85,10 +89,20 @@ def _fwd_kernel(
)
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
offs_qpe = (
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_qbs
+ cur_head * stride_qh
+ offs_dpe[None, :]
)
qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0)
# stage1: compute scores with prefix
offs_n = tl.arange(0, BLOCK_N)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32)
deno = tl.zeros([BLOCK_M], dtype=tl.float32)
e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
......@@ -110,6 +124,18 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
if BLOCK_DPE > 0:
offs_kpe = (
offs_kv_loc[None, :] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Buffer + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
......@@ -125,7 +151,7 @@ def _fwd_kernel(
offs_buf_v = (
offs_kv_loc[:, None] * stride_buf_vbs
+ cur_kv_head * stride_buf_vh
+ offs_d[None, :]
+ offs_dv[None, :]
)
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
......@@ -150,6 +176,21 @@ def _fwd_kernel(
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
if BLOCK_DPE > 0:
offs_kpe = (
(cur_seq_extend_start_contiguous + start_n + offs_n[None, :])
* stride_kbs
+ cur_kv_head * stride_kh
+ offs_dpe[:, None]
)
kpe = tl.load(
K_Extend + offs_kpe,
mask=mask_n[None, :],
other=0.0,
)
qk += tl.dot(qpe, kpe)
qk *= sm_scale
if logit_cap > 0:
......@@ -169,7 +210,7 @@ def _fwd_kernel(
offs_v = (
(cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs
+ cur_kv_head * stride_vh
+ offs_d[None, :]
+ offs_dv[None, :]
)
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
p = p.to(v.dtype)
......@@ -181,7 +222,7 @@ def _fwd_kernel(
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
+ cur_head * stride_oh
+ offs_d[None, :]
+ offs_dv[None, :]
)
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
......@@ -217,8 +258,17 @@ def extend_attention_fwd(
o_extend.shape[-1],
)
assert Lq == Lk and Lk == Lv and Lv == Lo
assert Lq in {16, 32, 64, 128, 256}
assert Lq == Lk and Lv == Lo
assert Lq in {16, 32, 64, 128, 256, 576}
assert Lv in {16, 32, 64, 128, 256, 512}
if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lq
BLOCK_DPE = 0
BLOCK_DV = Lv
if CUDA_CAPABILITY[0] >= 8:
BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
......@@ -260,7 +310,9 @@ def extend_attention_fwd(
v_buffer.stride(0),
v_buffer.stride(1),
req_to_tokens.stride(0),
BLOCK_DMODEL=Lq,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
......
......@@ -38,16 +38,22 @@ class RadixAttention(nn.Module):
num_kv_heads: int,
layer_id: int,
logit_cap: int = -1,
v_head_dim: int = -1,
):
super().__init__()
self.tp_q_head_num = num_heads
self.tp_k_head_num = num_kv_heads
self.tp_v_head_num = num_kv_heads
self.head_dim = head_dim
self.qk_head_dim = head_dim
self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
self.scaling = scaling
self.layer_id = layer_id
if not global_server_args_dict.get("disable_flashinfer", False):
if (
not global_server_args_dict.get("disable_flashinfer", False)
and self.qk_head_dim == self.v_head_dim
):
self.extend_forward = self.extend_forward_flashinfer
self.decode_forward = self.decode_forward_flashinfer
else:
......@@ -57,13 +63,17 @@ class RadixAttention(nn.Module):
self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0
def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
extend_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
k.contiguous(),
v.contiguous(),
o.view(-1, self.tp_q_head_num, self.head_dim),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
input_metadata.req_to_token_pool.req_to_token,
......@@ -82,14 +92,17 @@ class RadixAttention(nn.Module):
return o
def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata):
o = torch.empty_like(q)
if self.qk_head_dim != self.v_head_dim:
o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim))
else:
o = torch.empty_like(q)
self.store_kv_cache(k, v, input_metadata)
token_attention_fwd(
q.view(-1, self.tp_q_head_num, self.head_dim),
q.view(-1, self.tp_q_head_num, self.qk_head_dim),
input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
o.view(-1, self.tp_q_head_num, self.head_dim),
o.view(-1, self.tp_q_head_num, self.v_head_dim),
input_metadata.req_to_token_pool.req_to_token,
input_metadata.req_pool_indices,
input_metadata.triton_start_loc,
......@@ -160,8 +173,8 @@ class RadixAttention(nn.Module):
return o.view(-1, self.tp_q_head_num * self.head_dim)
def forward(self, q, k, v, input_metadata: InputMetadata):
k = k.view(-1, self.tp_k_head_num, self.head_dim)
v = v.view(-1, self.tp_v_head_num, self.head_dim)
k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
if input_metadata.forward_mode == ForwardMode.EXTEND:
return self.extend_forward(q, k, v, input_metadata)
......
......@@ -54,6 +54,7 @@ def _fwd_kernel_stage1(
att_stride_h,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_N: tl.constexpr,
logit_cap: tl.constexpr,
):
......@@ -73,6 +74,10 @@ def _fwd_kernel_stage1(
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
off_qpe = cur_batch * stride_qbs + cur_head * stride_qh + offs_dpe
offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)
block_stard_index = start_n * BLOCK_N
......@@ -97,6 +102,19 @@ def _fwd_kernel_stage1(
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value = tl.sum(q[None, :] * k, 1)
if BLOCK_DPE > 0:
qpe = tl.load(Q + off_qpe + start_mark).to(REDUCE_TRITON_TYPE)
offs_buf_kpe = (
k_loc[:, None] * stride_buf_kbs
+ cur_kv_head * stride_buf_kh
+ offs_dpe[None, :]
)
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
).to(REDUCE_TRITON_TYPE)
att_value += tl.sum(qpe[None, :] * kpe, 1)
att_value *= sm_scale
if logit_cap > 0:
......@@ -192,7 +210,14 @@ def _token_att_m_fwd(
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 128, 256}
assert Lk in {16, 32, 64, 128, 256, 576}
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
else:
BLOCK_DMODEL = Lk
BLOCK_DPE = 0
batch, head_num = B_req_idx.shape[0], q.shape[1]
......@@ -220,7 +245,8 @@ def _token_att_m_fwd(
k_buffer.stride(1),
att_out.stride(0),
kv_group_num=kv_group_num,
BLOCK_DMODEL=Lk,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_N=BLOCK,
logit_cap=logit_cap,
num_warps=num_warps,
......
......@@ -29,7 +29,7 @@ from sglang.global_config import global_config
from sglang.srt.constrained import RegexGuide
from sglang.srt.constrained.jump_forward import JumpForwardMap
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixCache
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
......@@ -39,6 +39,7 @@ global_server_args_dict = {
"disable_flashinfer": False,
"disable_flashinfer_sampling": False,
"attention_reduce_in_fp32": False,
"enable_mla": False,
}
......@@ -289,7 +290,7 @@ class Batch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
token_to_kv_pool: BaseTokenToKVPool
tree_cache: RadixCache
# Batched arguments to model runner
......@@ -780,7 +781,7 @@ class InputMetadata:
seq_lens: torch.Tensor
positions: torch.Tensor
req_to_token_pool: ReqToTokenPool
token_to_kv_pool: TokenToKVPool
token_to_kv_pool: BaseTokenToKVPool
# For extend
extend_seq_lens: torch.Tensor
......
......@@ -57,32 +57,18 @@ class ReqToTokenPool:
self.can_use_mem_size = len(self.mem_state)
class TokenToKVPool:
class BaseTokenToKVPool:
"""A memory pool that maps a token to its kv cache locations"""
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
):
self.size = size
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
# Prefetch buffer
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.prefetch_chunk_size = 512
......@@ -90,15 +76,6 @@ class TokenToKVPool:
self.can_use_mem_size = self.size
self.clear()
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
return self.v_buffer[layer_id]
def get_kv_buffer(self, layer_id: int):
return self.k_buffer[layer_id], self.v_buffer[layer_id]
def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer)
......@@ -139,3 +116,67 @@ class TokenToKVPool:
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state[0] = False
class MHATokenToKVPool(BaseTokenToKVPool):
def __init__(
self,
size: int,
dtype: torch.dtype,
head_num: int,
head_dim: int,
layer_num: int,
):
super().__init__(size)
# [size, head_num, head_dim] for each layer
self.k_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
return self.k_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
return self.v_buffer[layer_id]
def get_kv_buffer(self, layer_id: int):
return self.k_buffer[layer_id], self.v_buffer[layer_id]
class MLATokenToKVPool(BaseTokenToKVPool):
def __init__(
self,
size: int,
dtype: torch.dtype,
kv_lora_rank: int,
qk_rope_head_dim: int,
layer_num: int,
):
super().__init__(size)
self.kv_lora_rank = kv_lora_rank
self.kv_buffer = [
torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
dtype=dtype,
device="cuda",
)
for _ in range(layer_num)
]
def get_key_buffer(self, layer_id: int):
return self.kv_buffer[layer_id]
def get_value_buffer(self, layer_id: int):
return self.kv_buffer[layer_id][..., : self.kv_lora_rank]
def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
from enum import IntEnum, auto
from typing import Optional
from transformers import PretrainedConfig
......@@ -20,6 +21,11 @@ from transformers import PretrainedConfig
from sglang.srt.hf_transformers_utils import get_config, get_context_length
class AttentionArch(IntEnum):
MLA = auto()
MHA = auto()
class ModelConfig:
def __init__(
self,
......@@ -55,6 +61,11 @@ class ModelConfig:
# FIXME: temporary special judge for deepseek v2 MLA architecture
if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
else:
self.attention_arch = AttentionArch.MHA
self.num_attention_heads = self.hf_config.num_attention_heads
self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
......
......@@ -47,7 +47,12 @@ from sglang.srt.managers.schedule_batch import (
InputMetadata,
global_server_args_dict,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
)
from sglang.srt.model_config import AttentionArch
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
......@@ -86,6 +91,7 @@ class ModelRunner:
"disable_flashinfer": server_args.disable_flashinfer,
"disable_flashinfer_sampling": server_args.disable_flashinfer_sampling,
"attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
"enable_mla": server_args.enable_mla,
}
)
......@@ -193,15 +199,23 @@ class ModelRunner:
available_gpu_memory = get_available_gpu_memory(
self.gpu_id, distributed=self.tp_size > 1
)
head_dim = self.model_config.head_dim
head_num = self.model_config.get_num_kv_heads(self.tp_size)
cell_size = (
head_num
* head_dim
* self.model_config.num_hidden_layers
* 2
* torch._utils._element_size(self.dtype)
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and self.server_args.enable_mla
):
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* self.model_config.num_hidden_layers
* torch._utils._element_size(self.dtype)
)
else:
cell_size = (
self.model_config.get_num_kv_heads(self.tp_size)
* self.model_config.head_dim
* self.model_config.num_hidden_layers
* 2
* torch._utils._element_size(self.dtype)
)
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
......@@ -241,13 +255,28 @@ class ModelRunner:
max_num_reqs,
self.model_config.context_len + 8,
)
self.token_to_kv_pool = TokenToKVPool(
self.max_total_num_tokens,
dtype=self.dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and self.server_args.enable_mla
):
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
dtype=self.dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info("using MLA Triton implementaion, flashinfer is disabled")
# FIXME: temporarily only Triton MLA is supported
self.server_args.disable_flashinfer = True
else:
self.token_to_kv_pool = MHATokenToKVPool(
self.max_total_num_tokens,
dtype=self.dtype,
head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
)
logger.info(
f"[gpu={self.gpu_id}] Memory pool end. "
f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
......
......@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.model_runner import InputMetadata
......@@ -312,6 +313,165 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV2AttentionMLA(nn.Module):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: int,
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
) -> None:
super().__init__()
self.layer_id = layer_id
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
)
rope_scaling["type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.attn = RadixAttention(
self.num_local_heads,
self.kv_lora_rank + self.qk_rope_head_dim,
self.scaling,
num_kv_heads=1,
layer_id=layer_id,
v_head_dim=self.kv_lora_rank,
)
kv_b_proj = self.kv_b_proj
w_kc, w_vc = kv_b_proj.weight.unflatten(
0, (-1, qk_nope_head_dim + v_head_dim)
).split([qk_nope_head_dim, v_head_dim], dim=1)
self.w_kc = w_kc
self.w_vc = w_vc
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
)
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope_out = q_input[..., : self.kv_lora_rank]
torch.bmm(q_nope.transpose(0, 1), self.w_kc, out=q_nope_out.transpose(0, 1))
k_input = self.kv_a_proj_with_mqa(hidden_states)[0].unsqueeze(1)
k_pe = k_input[..., self.kv_lora_rank :]
v_input = k_input[..., : self.kv_lora_rank]
v_input = self.kv_a_layernorm(v_input.contiguous())
k_input[..., : self.kv_lora_rank] = v_input
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
attn_bmm_output = attn_output.new_empty(
q_len, self.num_local_heads, self.v_head_dim
)
torch.bmm(
attn_output.transpose(0, 1),
self.w_vc.transpose(1, 2).contiguous(),
out=attn_bmm_output.transpose(0, 1),
)
attn_output = attn_bmm_output.flatten(1, 2)
output, _ = self.o_proj(attn_output)
return output
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
......@@ -326,22 +486,44 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = DeepseekV2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
if global_server_args_dict["enable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
else:
self.self_attn = DeepseekV2Attention(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=(
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
layer_id=layer_id,
)
if (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
......
......@@ -80,6 +80,7 @@ class ServerArgs:
disable_disk_cache: bool = False
enable_torch_compile: bool = False
enable_p2p_check: bool = False
enable_mla: bool = False
attention_reduce_in_fp32: bool = False
efficient_weight_load: bool = False
......@@ -393,6 +394,11 @@ class ServerArgs:
action="store_true",
help="Enable P2P check for GPU access, otherwise the p2p access is allowed by default.",
)
parser.add_argument(
"--enable-mla",
action="store_true",
help="Enable Multi-head Latent Attention (MLA) for DeepSeek-V2",
)
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment