Unverified commit 8e66fbec authored by Lianmin Zheng, committed by GitHub

Improve DP attention (#4390)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent f141298a
from __future__ import annotations

import functools
+import logging
+from contextlib import contextmanager
from typing import TYPE_CHECKING, Union

import torch

@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
    tensor_model_parallel_all_reduce,
)

+logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -86,6 +90,27 @@ def get_attention_dp_size():
    return _DP_SIZE

+@contextmanager
+def disable_dp_size():
+    """Temporarily patch the DP size to 1 until this context manager exits.
+
+    This is used by draft workers of speculative decoding, which run the draft
+    model with a different tp degree than the target model workers.
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size

def get_dp_local_info(forward_batch: ForwardBatch):
    dp_rank = get_attention_dp_rank()

@@ -159,7 +184,8 @@ def dp_gather(
        layer_id != "embedding" or get_attention_tp_rank() == 0
    ):
        assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
        ), "aliasing between global_tokens and local_tokens not allowed"
        memcpy_triton(
            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False

@@ -174,8 +200,9 @@ def dp_gather(
            torch.ops.sglang.inplace_all_reduce(
                global_tokens, group_name=get_tp_group().unique_name
            )
        else:
-            global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+            global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)


def dp_scatter(

@@ -186,6 +213,7 @@ def dp_scatter(
    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
    # since local_tokens may be padded for cuda graph
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)

    local_tokens.fill_(0)
    assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
...
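The new disable_dp_size() context manager above temporarily forces the reported DP size back to 1. A minimal usage sketch, assuming a hypothetical draft worker (run_draft_forward and draft_model_runner are illustrative names, not part of this commit):

from sglang.srt.layers.dp_attention import disable_dp_size, get_attention_dp_size

def run_draft_forward(draft_model_runner, forward_batch):
    # Hypothetical draft worker of speculative decoding: while this block is
    # active, get_attention_dp_size() returns 1, so DP-attention-specific
    # code paths behave as if DP attention were disabled for the draft model.
    with disable_dp_size():
        assert get_attention_dp_size() == 1
        return draft_model_runner.forward(forward_batch)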
@@ -23,6 +23,7 @@ import triton.language as tl
from torch import nn

from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_gather,
)
...
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
    """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
        # Parse args
        self.max_total_num_tokens = None
        self.server_args = server_args
...
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
        # Handle DP attention
        if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

        return ret

@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
+            global_num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0

-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
        torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
+            global_info.flatten(),
+            local_info,
            group=self.tp_cpu_group,
        )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and global_num_tokens.max().item() > 0:
+        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = self.get_idle_batch()

        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

    def get_idle_batch(self):
        idle_batch = ScheduleBatch.init_new(
...
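The scheduler change above replaces a separate all-reduce for the CUDA-graph flag with a single all-gather of a packed per-rank record (num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend_in_batch). A single-process sketch of that layout and how it is unpacked; the gather itself is emulated by stacking, so no process group is needed and the numbers are made up:

import torch

dp_size, attn_tp_size = 2, 2

def make_local_info(num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend):
    # One row of 4 int64 values per rank, in the same order as the diff.
    return torch.tensor(
        [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend],
        dtype=torch.int64,
    )

# Every attention-TP rank inside a DP group publishes the same values, so the
# gathered tensor has shape (dp_size, attn_tp_size, 4) and column 0 suffices.
per_rank = [
    make_local_info(7, 1, 7, 0),   # DP rank 0, both attn-TP ranks
    make_local_info(7, 1, 7, 0),
    make_local_info(5, 0, 3, 1),   # DP rank 1, both attn-TP ranks
    make_local_info(5, 0, 3, 1),
]
global_info = torch.stack(per_rank).view(dp_size, attn_tp_size, 4)

global_num_tokens = global_info[:, 0, 0].tolist()      # [7, 5]
can_cuda_graph = min(global_info[:, 0, 1].tolist())    # 0: any rank can veto
is_extend_in_batch = global_info[:, 0, 3].tolist()     # [0, 1]
print(global_num_tokens, can_cuda_graph, is_extend_in_batch)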
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
)
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip

_is_hip = is_hip()

@@ -174,6 +174,7 @@ class CudaGraphRunner:
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size

@@ -236,7 +237,7 @@ class CudaGraphRunner:
        if self.enable_dp_attention:
            self.gathered_buffer = torch.zeros(
                (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,

@@ -276,13 +277,12 @@ class CudaGraphRunner:
    def can_run(self, forward_batch: ForwardBatch):
        if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)

            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
            )
        else:
            is_bs_supported = (

@@ -304,6 +304,9 @@ class CudaGraphRunner:
    def capture(self):
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
            # Reverse the order to enable better memory sharing across cuda graphs.
            capture_range = (
                tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@
                else reversed(self.capture_bs)
            )
            for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                with patch_model(
                    self.model_runner.model,
                    bs in self.compile_bs,

@@ -345,8 +358,18 @@ class CudaGraphRunner:
        mrope_positions = self.mrope_positions[:, :bs]

        if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.tp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
        else:
            global_num_tokens = None
            gathered_buffer = None

@@ -371,7 +394,7 @@ class CudaGraphRunner:
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
-            global_num_tokens_cpu=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
            gathered_buffer=gathered_buffer,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,

@@ -392,6 +415,9 @@ class CudaGraphRunner:
        # Run and capture
        def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
            return logits_output.next_token_logits, logits_output.hidden_states

@@ -426,7 +452,7 @@ class CudaGraphRunner:
            self.capture_hidden_mode = hidden_mode_from_spec_info
            self.capture()

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay_prepare(self, forward_batch: ForwardBatch):
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size

@@ -435,7 +461,7 @@ class CudaGraphRunner:
        # Pad
        if self.enable_dp_attention:
            index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
            )
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)

@@ -459,6 +485,8 @@ class CudaGraphRunner:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

        if hasattr(forward_batch.spec_info, "hidden_states"):
            self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states

@@ -475,14 +503,29 @@ class CudaGraphRunner:
            seq_lens_cpu=self.seq_lens_cpu,
        )

+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
        # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]

        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
            hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
            ),
        )

        return logits_output
...
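When capturing with DP attention, the per-rank token counts are derived by splitting a total as evenly as possible across DP ranks, with the remainder going to the lowest-indexed ranks (the diff indexes the remainder with bs; the sketch below uses a single total for clarity and only illustrates the formula):

def split_tokens(total: int, dp_size: int) -> list[int]:
    # total // dp_size everywhere, plus one extra token on the first
    # (total % dp_size) ranks; the bool comparison counts as 0 or 1.
    return [total // dp_size + (i < total % dp_size) for i in range(dp_size)]

assert split_tokens(10, 4) == [3, 3, 2, 2]
assert sum(split_tokens(10, 4)) == 10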
@@ -38,7 +38,7 @@ import triton
import triton.language as tl

from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend, next_power_of_2
+from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend

@@ -263,15 +263,24 @@ class ForwardBatch:
            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
        )

+        # For DP attention
        if batch.global_num_tokens is not None:
            ret.global_num_tokens_cpu = batch.global_num_tokens
-            max_len = max(ret.global_num_tokens_cpu)
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
            ret.gathered_buffer = torch.zeros(
-                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                dtype=model_runner.dtype,
                device=device,
            )

        if ret.forward_mode.is_idle():
            ret.positions = torch.empty((0,), device=device)

        return ret
...
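The gathered buffer is now sized by the sum of per-rank token counts instead of max(len) * tp_size, and each DP rank owns a contiguous slice located by a prefix sum. The plain-tensor sketch below illustrates that layout only; it is not the real distributed code, and the per-rank counts are invented:

import torch

hidden_size = 8
global_num_tokens = [3, 1, 4]                     # per-DP-rank token counts
gathered_buffer = torch.zeros(sum(global_num_tokens), hidden_size)

def local_slice(dp_rank: int) -> tuple[int, int]:
    # Start of this rank's slice is the prefix sum of earlier ranks' counts.
    start = sum(global_num_tokens[:dp_rank])
    return start, start + global_num_tokens[dp_rank]

for rank in range(len(global_num_tokens)):
    start, end = local_slice(rank)
    gathered_buffer[start:end] = rank             # each rank fills its own slice

print(gathered_buffer[:, 0].tolist())             # [0.0, 0.0, 0.0, 1.0, 2.0, 2.0, 2.0, 2.0]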
@@ -26,15 +26,20 @@ from transformers import PretrainedConfig
from vllm import _custom_ops as ops

from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
-    get_tp_group,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
    decode_attention_fwd_grouped_rope,
)
+from sglang.srt.layers.dp_attention import (
+    dp_gather,
+    dp_scatter,
+    get_attention_dp_size,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
@@ -230,6 +235,7 @@ class DeepseekV2Attention(nn.Module):
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
        layer_id=None,
+        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()

@@ -241,10 +247,14 @@ class DeepseekV2Attention(nn.Module):
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
        self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

@@ -272,6 +282,8 @@ class DeepseekV2Attention(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("q_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
        )
        self.kv_a_proj_with_mqa = ReplicatedLinear(

@@ -296,6 +308,9 @@ class DeepseekV2Attention(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
+            reduce_results=reduce_results,
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
        )
        rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope_wrapper(

@@ -330,6 +345,12 @@ class DeepseekV2Attention(nn.Module):
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
        if self.q_lora_rank is not None:
            q = self.q_a_proj(hidden_states)[0]
            q = self.q_a_layernorm(q)
@@ -385,8 +406,8 @@ class DeepseekV2AttentionMLA(nn.Module):
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
-        layer_id=None,
-        use_dp=False,
+        reduce_results: bool = True,
+        layer_id: int = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

@@ -398,56 +419,17 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.v_head_dim = v_head_dim
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
+        self.dp_size = get_attention_dp_size()
+        attn_tp_rank = get_attention_tp_rank()
+        attn_tp_size = get_attention_tp_size()
+
        self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
+        assert num_heads % attn_tp_size == 0
+        self.num_local_heads = num_heads // attn_tp_size
        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
-        if use_dp:
-            # For data parallel attention
-            if self.q_lora_rank is not None:
-                self.q_a_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.q_lora_rank,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_a_proj", prefix),
-                )
-                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-                self.q_b_proj = ReplicatedLinear(
-                    q_lora_rank,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_b_proj", prefix),
-                )
-            else:
-                self.q_proj = ReplicatedLinear(
-                    self.hidden_size,
-                    self.num_heads * self.qk_head_dim,
-                    bias=False,
-                    quant_config=quant_config,
-                    prefix=add_prefix("q_proj", prefix),
-                )
-            self.kv_b_proj = ReplicatedLinear(
-                self.kv_lora_rank,
-                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("kv_b_proj", prefix),
-            )
-            # O projection.
-            self.o_proj = ReplicatedLinear(
-                self.num_heads * self.v_head_dim,
-                self.hidden_size,
-                bias=False,
-                quant_config=quant_config,
-                prefix=add_prefix("o_proj", prefix),
-            )
-        else:
        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(

@@ -464,6 +446,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_b_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
            )
        else:
            self.q_proj = ColumnParallelLinear(

@@ -472,6 +456,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("q_proj", prefix),
+                tp_rank=attn_tp_rank,
+                tp_size=attn_tp_size,
            )
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,

@@ -479,6 +465,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            bias=False,
            quant_config=quant_config,
            prefix=add_prefix("kv_b_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
        )
        # O projection.
        self.o_proj = RowParallelLinear(

@@ -486,7 +474,10 @@ class DeepseekV2AttentionMLA(nn.Module):
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
+            reduce_results=reduce_results,
            prefix=add_prefix("o_proj", prefix),
+            tp_rank=attn_tp_rank,
+            tp_size=attn_tp_size,
        )

        self.kv_a_proj_with_mqa = ReplicatedLinear(

@@ -542,18 +533,17 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.w_vc = None
        self.w_scale = None
-    def forward(
-        self,
-        positions: torch.Tensor,
-        hidden_states: torch.Tensor,
-        forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
-        def no_absorb() -> bool:
-            if global_server_args_dict["enable_flashinfer_mla"]:
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
+        self.flashinfer_mla_disable_ragged = global_server_args_dict[
+            "flashinfer_mla_disable_ragged"
+        ]
+        self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+
+    def no_absorb(self, forward_batch: ForwardBatch) -> bool:
+        if self.enable_flashinfer_mla:
            # Flashinfer MLA: Do not absorb when enabling ragged prefill
            return (
-                not global_server_args_dict["flashinfer_mla_disable_ragged"]
+                not self.flashinfer_mla_disable_ragged
                and forward_batch.forward_mode.is_extend()
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()

@@ -568,12 +558,24 @@ class DeepseekV2AttentionMLA(nn.Module):
                and forward_batch.extend_prefix_lens.sum() == 0
            )

-        if no_absorb():
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states
+
+        if self.no_absorb(forward_batch):
            return self.forward_normal(positions, hidden_states, forward_batch)
        else:
            if _is_hip:
                if (
-                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
+                    self.rocm_fused_decode_mla
                    and forward_batch.forward_mode.is_decode()
                ):
                    return self.forward_absorb_fused_mla_rope(

@@ -845,34 +847,6 @@ class DeepseekV2AttentionMLA(nn.Module):
        return output
-def all_gather(
-    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
-):
-    all_lens = forward_batch.global_num_tokens_cpu
-    max_len = max(forward_batch.global_num_tokens_cpu)
-
-    if world_size == 1:
-        return input_tensor, 0, all_lens[0]
-
-    padded_tensor = torch.nn.functional.pad(
-        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
-    )
-
-    group.all_gather_into_tensor(forward_batch.gathered_buffer, padded_tensor)
-
-    gathered_tensors = torch.concat(
-        [
-            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
-            for i in range(world_size)
-        ]
-    )
-
-    start_index = 0 if rank == 0 else sum(all_lens[:rank])
-    end_index = start_index + all_lens[rank]
-
-    return gathered_tensors, start_index, end_index


class DeepseekV2DecoderLayer(nn.Module):

    def __init__(

@@ -888,14 +862,10 @@ class DeepseekV2DecoderLayer(nn.Module):
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
-        self.enable_dp_attention = (
-            not global_server_args_dict["disable_mla"]
-            and global_server_args_dict["enable_dp_attention"]
-        )
-        if self.enable_dp_attention:
-            self.tp_rank = get_tensor_model_parallel_rank()
-            self.tp_size = get_tensor_model_parallel_world_size()
-            self.tp_group = get_tp_group()
+        self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
+        self.layer_id = layer_id
+        self.dp_size = get_attention_dp_size()
        if not global_server_args_dict["disable_mla"]:
            self.self_attn = DeepseekV2AttentionMLA(
                config=config,

@@ -913,7 +883,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
-                use_dp=self.enable_dp_attention,
+                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )
        else:

@@ -933,8 +903,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                max_position_embeddings=max_position_embeddings,
                quant_config=quant_config,
                layer_id=layer_id,
+                reduce_results=False,
                prefix=add_prefix("self_attn", prefix),
            )

        if is_nextn or (
            config.n_routed_experts is not None
            and layer_id >= config.first_k_dense_replace
@@ -965,33 +937,47 @@ class DeepseekV2DecoderLayer(nn.Module):
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
    ) -> torch.Tensor:
-        # Self Attention
-        if not forward_batch.forward_mode.is_idle():
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

+        # Scatter
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
+        # Self Attention
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
-            hidden_states, residual = self.post_attention_layernorm(
-                hidden_states, residual
-            )

-        # Fully Connected
-        if self.enable_dp_attention:
-            hidden_states, start_idx, end_idx = all_gather(
-                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
-            )
-            hidden_states = self.mlp(hidden_states)
-            hidden_states = hidden_states[start_idx:end_idx]
-        else:
-            hidden_states = self.mlp(hidden_states)
+        # Gather
+        if get_tensor_model_parallel_world_size() > 1:
+            # all gather and all reduce
+            if self.dp_size != 1:
+                hidden_states, local_hidden_states = (
+                    forward_batch.gathered_buffer,
+                    hidden_states,
+                )
+                dp_gather(
+                    hidden_states, local_hidden_states, forward_batch, self.layer_id
+                )
+            else:
+                hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
@@ -1027,12 +1013,27 @@ class DeepseekV2Model(nn.Module):
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.dp_size = get_attention_dp_size()

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
+        # Gather
+        if self.dp_size != 1:
+            input_ids, local_input_ids = (
+                torch.empty(
+                    (forward_batch.gathered_buffer.shape[0],),
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                ),
+                input_ids,
+            )
+            dp_gather(input_ids, local_input_ids, forward_batch, "embedding")
+
        hidden_states = self.embed_tokens(input_ids)
        residual = None
        for i in range(len(self.layers)):
@@ -1059,15 +1060,6 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
-        if global_server_args_dict["enable_dp_attention"]:
-            self.lm_head = ReplicatedLinear(
-                config.hidden_size,
-                config.vocab_size,
-                bias=False,
-                prefix=add_prefix("lm_head", prefix),
-            )
-            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
-        else:
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,

@@ -1075,6 +1067,7 @@ class DeepseekV2ForCausalLM(nn.Module):
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
+        self.dp_size = get_attention_dp_size()

    @torch.no_grad()
    def forward(

@@ -1084,6 +1077,16 @@ class DeepseekV2ForCausalLM(nn.Module):
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
+
+        if self.dp_size != 1:
+            # important: forward batch.gathered_buffer is used both after scatter and after gather.
+            # be careful about this!
+            hidden_states, global_hidden_states = (
+                forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
+                hidden_states,
+            )
+            dp_scatter(hidden_states, global_hidden_states, forward_batch)
+
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )
...
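In the rewritten decoder layer, dp_scatter narrows the globally gathered hidden states down to this rank's local tokens before attention, and dp_gather rebuilds the full token set before the MLP/MoE. A toy single-process sketch of that flow, with the scatter and gather emulated by slicing (the real dp_gather also all-reduces partial results across attention-TP ranks, which is not shown here):

import torch

# Per-DP-rank token counts and a hypothetical rank's slice boundaries.
global_num_tokens = [3, 2]
dp_rank, hidden_size = 0, 4
start = sum(global_num_tokens[:dp_rank])
end = start + global_num_tokens[dp_rank]

global_hidden_states = torch.randn(sum(global_num_tokens), hidden_size)
gathered_buffer = torch.zeros_like(global_hidden_states)

# "Scatter": attention only sees this rank's tokens.
local_hidden_states = global_hidden_states[start:end].clone()
attn_out = local_hidden_states  # stand-in for self_attn(...)

# "Gather": copy the local result back into the shared buffer before the MLP.
gathered_buffer[start:end] = attn_out
mlp_input = gathered_buffer  # the MLP/MoE then runs over all 5 tokens
print(mlp_input.shape)  # torch.Size([5, 4])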
@@ -262,14 +262,14 @@ class ServerArgs:
        # Data parallelism attention
        if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
            )

        # Speculative Decoding
...
@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
                "--tp",
                "2",
                "--enable-dp-attention",
+                "--dp",
+                "2",
            ],
        )
...