Unverified Commit 8e66fbec authored by Lianmin Zheng, committed by GitHub

Improve DP attention (#4390)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent f141298a
from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union
import torch
......@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
......@@ -86,6 +90,27 @@ def get_attention_dp_size():
return _DP_SIZE
@contextmanager
def disable_dp_size():
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
old_dp_size = _DP_SIZE
_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
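
For illustration, here is a minimal, self-contained sketch of how this context manager is meant to be used: a module-level value is saved, forced to 1 for the duration of the block, and restored on exit. The `init_dp` helper and the standalone `_DP_SIZE` global below are stand-ins for the real sglang initialization, not its API:

```python
from contextlib import contextmanager

_DP_SIZE = None  # stand-in for the value set during DP-attention initialization

def init_dp(dp_size: int) -> None:
    global _DP_SIZE
    _DP_SIZE = dp_size

@contextmanager
def disable_dp_size():
    # Save the current DP size, force it to 1, and restore it on exit.
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"
    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield
    finally:
        _DP_SIZE = old_dp_size

init_dp(4)
with disable_dp_size():
    assert _DP_SIZE == 1   # e.g. run the draft model here without DP attention
assert _DP_SIZE == 4       # original value restored afterwards
```
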
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
......@@ -159,7 +184,8 @@ def dp_gather(
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
global_tokens.untyped_storage().data_ptr()
!= local_tokens.untyped_storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
......@@ -174,8 +200,9 @@ def dp_gather(
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
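
One detail worth noting in the `global_tokens[:] = ...` line above: the reduced result must land in the caller-provided buffer, so the assignment writes through the existing storage instead of rebinding the local name. A small illustration of the difference, using a doubling as a stand-in for `tensor_model_parallel_all_reduce`:

```python
import torch

def reduce_rebind(buf: torch.Tensor) -> None:
    # Rebinding the parameter name leaves the caller's tensor untouched.
    buf = buf * 2

def reduce_inplace(buf: torch.Tensor) -> None:
    # Writing through [:] updates the caller's storage, which is what the
    # gathered buffer needs because it is reused by downstream layers.
    buf[:] = buf * 2

x = torch.ones(4)
reduce_rebind(x)
print(x)           # tensor([1., 1., 1., 1.])  -- unchanged
reduce_inplace(x)
print(x)           # tensor([2., 2., 2., 2.])  -- visible to the caller
```
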
def dp_scatter(
......@@ -186,6 +213,7 @@ def dp_scatter(
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
......
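
A note on the new `local_tokens.fill_(0)` in `dp_scatter`: as the comment says, `local_tokens` may be padded for CUDA graphs, and the scatter only overwrites the first `local_num_tokens` rows, so without the zero-fill the padded tail would keep stale values from an earlier step. A small CPU-only illustration of that effect (toy shapes, no Triton):

```python
import torch

local_num_tokens = 3
padded = torch.full((5, 2), 9.0)        # pretend these are stale values in a reused buffer
new_rows = torch.arange(6.0).reshape(3, 2)

padded[:local_num_tokens] = new_rows    # only the first rows are overwritten
print(padded[3:])                       # tensor([[9., 9.], [9., 9.]]) -- stale tail

padded.fill_(0)                         # what dp_scatter now does first
padded[:local_num_tokens] = new_rows
print(padded[3:])                       # tensor([[0., 0.], [0., 0.]]) -- well-defined tail
```
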
......@@ -23,6 +23,7 @@ import triton.language as tl
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
......
......@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args, port_args) -> None:
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
......
......@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Handle DP attention
if self.server_args.enable_dp_attention:
ret = self.prepare_dp_attn_batch(ret)
ret, _ = self.prepare_dp_attn_batch(ret)
return ret
......@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
global_num_tokens_for_logprob = sum(
[
# We should have at least 1 token to sample in every case.
max(extend_len - logprob_start_len, 1)
for logprob_start_len, extend_len in zip(
local_batch.extend_logprob_start_lens, local_batch.extend_lens
)
]
)
if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
can_cuda_graph = 1
else:
can_cuda_graph = 0
if not self.spec_algorithm.is_none():
# TODO(sang): Support cuda graph when idle batch is there.
if local_batch is None or local_batch.forward_mode.is_idle():
can_cuda_graph = 0
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
is_extend_in_batch = (
local_batch.forward_mode.is_extend() if local_batch else False
)
local_info = torch.tensor(
[
num_tokens,
can_cuda_graph,
global_num_tokens_for_logprob,
is_extend_in_batch,
],
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 4),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()
if local_batch is None and global_num_tokens.max().item() > 0:
if local_batch is None and max(global_num_tokens) > 0:
local_batch = self.get_idle_batch()
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(1 if local_batch.forward_mode.is_decode_or_idle() else 0),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
local_batch.can_run_dp_cuda_graph = can_cuda_graph
return local_batch
return local_batch, any(is_extend_in_batch)
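
To make the new synchronization easier to follow: each scheduler rank packs four integers (token count, CUDA-graph eligibility, logprob token count, extend flag) into a single tensor, performs one `all_gather_into_tensor` over the TP CPU group, and then reads one attention-TP rank per DP group. Below is a shape-only sketch that simulates the gather in a single process, assuming `dp_size=2` and `attn_tp_size=2`; the second DP group's values are made up for the example:

```python
import torch

dp_size, attn_tp_size = 2, 2

# What one DP group contributes:
# [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch]
local_info = torch.tensor([7, 1, 7, 0], dtype=torch.int64)
other_info = torch.tensor([5, 0, 3, 1], dtype=torch.int64)  # assumed second DP group

# After all_gather_into_tensor every rank holds one row per (dp, attn_tp) pair;
# ranks inside the same DP group report identical values, so broadcasting works here.
global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
global_info[:] = torch.stack([local_info, other_info]).unsqueeze(1)

# Read one attention-TP rank (index 0) per DP group, as the scheduler does.
global_num_tokens = global_info[:, 0, 0].tolist()              # [7, 5]
can_cuda_graph = min(global_info[:, 0, 1].tolist())            # 0: any rank can veto
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()  # [7, 3]
is_extend_in_batch = any(global_info[:, 0, 3].tolist())        # True
```
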
def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
......
......@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
......@@ -174,6 +174,7 @@ class CudaGraphRunner:
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
......@@ -236,7 +237,7 @@ class CudaGraphRunner:
if self.enable_dp_attention:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size,
self.max_bs * self.dp_size * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
......@@ -276,13 +277,12 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(
forward_batch.global_num_tokens_cpu
), max(forward_batch.global_num_tokens_cpu)
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
total_global_tokens in self.graphs
if self.disable_padding
else max_num_tokens <= self.max_bs
else total_global_tokens <= self.max_bs
)
else:
is_bs_supported = (
......@@ -304,6 +304,9 @@ class CudaGraphRunner:
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
......@@ -311,6 +314,16 @@ class CudaGraphRunner:
else reversed(self.capture_bs)
)
for bs in capture_range:
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
......@@ -345,8 +358,18 @@ class CudaGraphRunner:
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < bs % self.dp_size)
for i in range(self.dp_size)
],
dtype=torch.int32,
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
else:
global_num_tokens = None
gathered_buffer = None
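
The list comprehension used for `global_num_tokens_gpu` above spreads the captured token count as evenly as possible over the DP ranks, with the first `remainder` ranks getting one extra token. A short worked example of that split pattern, using a single value `n` for both the quotient and the remainder (which matches the plain decode case where `num_tokens == bs`):

```python
def split_tokens(n: int, dp_size: int) -> list:
    # Mirrors: n // dp_size + (i < n % dp_size) for each rank i.
    return [n // dp_size + (i < n % dp_size) for i in range(dp_size)]

assert split_tokens(10, 4) == [3, 3, 2, 2]   # remainder 2 goes to ranks 0 and 1
assert sum(split_tokens(10, 4)) == 10        # no tokens lost
assert split_tokens(8, 4) == [2, 2, 2, 2]    # even case
```
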
......@@ -371,7 +394,7 @@ class CudaGraphRunner:
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,
global_num_tokens_cpu=global_num_tokens,
global_num_tokens_gpu=global_num_tokens,
gathered_buffer=gathered_buffer,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
......@@ -392,6 +415,9 @@ class CudaGraphRunner:
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits, logits_output.hidden_states
......@@ -426,7 +452,7 @@ class CudaGraphRunner:
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
def replay_prepare(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
......@@ -435,7 +461,7 @@ class CudaGraphRunner:
# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
......@@ -459,6 +485,8 @@ class CudaGraphRunner:
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
......@@ -475,14 +503,29 @@ class CudaGraphRunner:
seq_lens_cpu=self.seq_lens_cpu,
)
# Store fields
self.raw_bs = raw_bs
self.raw_num_token = raw_num_token
self.bs = bs
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
if not skip_attn_backend_init:
self.replay_prepare(forward_batch)
else:
# In speculative decoding, these two fields are still needed.
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
self.graphs[bs].replay()
next_token_logits, hidden_states = self.output_buffers[bs]
self.graphs[self.bs].replay()
next_token_logits, hidden_states = self.output_buffers[self.bs]
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits[:raw_num_token],
next_token_logits=next_token_logits[: self.raw_num_token],
hidden_states=(
hidden_states[:raw_num_token] if hidden_states is not None else None
hidden_states[: self.raw_num_token]
if hidden_states is not None
else None
),
)
return logits_output
......
......@@ -38,7 +38,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
......@@ -263,15 +263,24 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
# For DP attention
if batch.global_num_tokens is not None:
ret.global_num_tokens_cpu = batch.global_num_tokens
max_len = max(ret.global_num_tokens_cpu)
ret.global_num_tokens_gpu = torch.tensor(
batch.global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
batch.global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
sum_len = sum(batch.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
return ret
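
The buffer-sizing change above is the main memory win: the gathered buffer now holds exactly `sum(global_num_tokens)` rows instead of `max(global_num_tokens) * tp_size`, so ranks no longer pad to the longest batch. A rough back-of-the-envelope comparison with made-up per-rank token counts (four ranks, bf16, hidden size 7168):

```python
global_num_tokens = [3, 17, 5, 9]   # assumed per-rank token counts
hidden_size = 7168
bytes_per_elem = 2                  # bf16

old_rows = max(global_num_tokens) * len(global_num_tokens)   # max_len * tp_size = 17 * 4 = 68
new_rows = sum(global_num_tokens)                            # 34

saved = (old_rows - new_rows) * hidden_size * bytes_per_elem
print(old_rows, new_rows, f"{saved / 2**20:.2f} MiB saved")  # 68 34 0.46 MiB saved
```
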
......
......@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
from vllm import _custom_ops as ops
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.dp_attention import (
dp_gather,
dp_scatter,
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
......@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
......@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
assert num_heads % attn_tp_size == 0
self.num_local_heads = num_heads // attn_tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
......@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
......@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
reduce_results=reduce_results,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope_wrapper(
......@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
q = self.q_a_layernorm(q)
......@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
layer_id=None,
use_dp=False,
reduce_results: bool = True,
layer_id: int = None,
prefix: str = "",
) -> None:
super().__init__()
......@@ -398,96 +419,66 @@ class DeepseekV2AttentionMLA(nn.Module):
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads if use_dp else num_heads // tp_size
assert num_heads % attn_tp_size == 0
self.num_local_heads = num_heads // attn_tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if use_dp:
# For data parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ReplicatedLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ReplicatedLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ReplicatedLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
)
# O projection.
self.o_proj = ReplicatedLinear(
self.num_heads * self.v_head_dim,
# For tensor parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
prefix=add_prefix("q_a_proj", prefix),
)
else:
# For tensor parallel attention
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_a_proj", prefix),
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_b_proj", prefix),
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("q_proj", prefix),
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
prefix=add_prefix("q_b_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=add_prefix("o_proj", prefix),
prefix=add_prefix("q_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=add_prefix("kv_b_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
# O projection.
self.o_proj = RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("o_proj", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
......@@ -542,38 +533,49 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_vc = None
self.w_scale = None
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.enable_flashinfer_mla:
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not self.flashinfer_mla_disable_ragged
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
), "short-circuiting allreduce will lead to hangs"
return hidden_states
def no_absorb() -> bool:
if global_server_args_dict["enable_flashinfer_mla"]:
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not global_server_args_dict["flashinfer_mla_disable_ragged"]
and forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
return (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
and forward_batch.extend_prefix_lens.sum() == 0
)
if no_absorb():
if self.no_absorb(forward_batch):
return self.forward_normal(positions, hidden_states, forward_batch)
else:
if _is_hip:
if (
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
self.rocm_fused_decode_mla
and forward_batch.forward_mode.is_decode()
):
return self.forward_absorb_fused_mla_rope(
......@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
return output
def all_gather(
input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
):
all_lens = forward_batch.global_num_tokens_cpu
max_len = max(forward_batch.global_num_tokens_cpu)
if world_size == 1:
return input_tensor, 0, all_lens[0]
padded_tensor = torch.nn.functional.pad(
input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
)
group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
gathered_tensors = torch.concat(
[
forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
for i in range(world_size)
]
)
start_index = 0 if rank == 0 else sum(all_lens[:rank])
end_index = start_index + all_lens[rank]
return gathered_tensors, start_index, end_index
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
......@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.enable_dp_attention = (
not global_server_args_dict["disable_mla"]
and global_server_args_dict["enable_dp_attention"]
)
if self.enable_dp_attention:
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_group = get_tp_group()
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
self.layer_id = layer_id
self.dp_size = get_attention_dp_size()
if not global_server_args_dict["disable_mla"]:
self.self_attn = DeepseekV2AttentionMLA(
config=config,
......@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
use_dp=self.enable_dp_attention,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
)
else:
......@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
layer_id=layer_id,
reduce_results=False,
prefix=add_prefix("self_attn", prefix),
)
if is_nextn or (
config.n_routed_experts is not None
and layer_id >= config.first_k_dense_replace
......@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Scatter
if self.dp_size != 1:
# important: forward_batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
# Self Attention
if not forward_batch.forward_mode.is_idle():
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.dp_size != 1:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather(
hidden_states, local_hidden_states, forward_batch, self.layer_id
)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Fully Connected
if self.enable_dp_attention:
hidden_states, start_idx, end_idx = all_gather(
hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states[start_idx:end_idx]
else:
hidden_states = self.mlp(hidden_states)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
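
The restructured decoder layer follows a scatter → attend → gather pattern: each rank's attention runs only on its local slice of the batch, and the slices are re-assembled into the full batch before the shared MLP. The single-process sketch below captures just the shape bookkeeping, with plain slicing and `torch.cat` standing in for `dp_scatter` and `dp_gather`:

```python
import torch

# Assumed toy setup: 2 DP ranks sharing a global batch of 5 tokens, split [2, 3].
global_num_tokens = [2, 3]
hidden_size = 8
global_hidden = torch.randn(sum(global_num_tokens), hidden_size)

def scatter(global_tokens: torch.Tensor, rank: int) -> torch.Tensor:
    # Stand-in for dp_scatter: take this rank's contiguous rows.
    start = sum(global_num_tokens[:rank])
    return global_tokens[start : start + global_num_tokens[rank]]

def gather(parts) -> torch.Tensor:
    # Stand-in for dp_gather: concatenate the per-rank slices back together.
    return torch.cat(parts, dim=0)

# Attention: each rank sees only its own tokens (identity here for brevity).
attn_out = [scatter(global_hidden, r) for r in range(len(global_num_tokens))]

# MLP: every rank operates on the full re-gathered batch again.
mlp_in = gather(attn_out)
assert mlp_in.shape == global_hidden.shape
```
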
......@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_attention_dp_size()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Gather
if self.dp_size != 1:
input_ids, local_input_ids = (
torch.empty(
(forward_batch.gathered_buffer.shape[0],),
dtype=input_ids.dtype,
device=input_ids.device,
),
input_ids,
)
dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
......@@ -1059,22 +1060,14 @@ class DeepseekV2ForCausalLM(nn.Module):
self.model = DeepseekV2Model(
config, quant_config, prefix=add_prefix("model", prefix)
)
if global_server_args_dict["enable_dp_attention"]:
self.lm_head = ReplicatedLinear(
config.hidden_size,
config.vocab_size,
bias=False,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
else:
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
)
self.logits_processor = LogitsProcessor(config)
self.dp_size = get_attention_dp_size()
@torch.no_grad()
def forward(
......@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, forward_batch)
if self.dp_size != 1:
# important: forward_batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
......
......@@ -262,14 +262,14 @@ class ServerArgs:
# Data parallelism attention
if self.enable_dp_attention:
self.dp_size = self.tp_size
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
assert (
self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
"Data parallel size is adjusted to be the same as tensor parallel size. "
)
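
With the new rule, the chunked prefill budget is divided by the DP size rather than halved unconditionally, so each attention-DP rank prefills proportionally smaller chunks. A quick arithmetic check under an assumed value of 8192:

```python
chunked_prefill_size = 8192   # assumed configured/default value
dp_size = 2                   # e.g. --enable-dp-attention --tp 2 --dp 2

chunked_prefill_size //= dp_size
assert chunked_prefill_size == 4096
```
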
# Speculative Decoding
......
......@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
],
)
......