"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "d8b5d1bef2affcec5228d0ed3a04ea47a8ae7cba"
Unverified commit 8e66fbec, authored by Lianmin Zheng, committed by GitHub

Improve DP attention (#4390)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent f141298a
 from __future__ import annotations

 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union

 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )

+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -86,6 +90,27 @@ def get_attention_dp_size():
     return _DP_SIZE


+@contextmanager
+def disable_dp_size():
+    """Patch the DP size to 1 temporarily until this context manager exits.
+
+    This is for draft workers of speculative decoding, which run the draft
+    model with a different tp degree from that of the target model workers.
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
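Note: `disable_dp_size` is the classic save/patch/restore pattern around a module-level global, with `try/finally` guaranteeing restoration even if the body raises. A minimal, self-contained sketch of the same pattern (a standalone toy global, not the actual sglang module state):

```python
from contextlib import contextmanager

_DP_SIZE = 4  # assume this was initialized elsewhere

@contextmanager
def disable_dp_size():
    global _DP_SIZE
    assert _DP_SIZE is not None, "dp attention not initialized!"
    old_dp_size = _DP_SIZE
    _DP_SIZE = 1
    try:
        yield                    # draft worker runs with dp size 1 here
    finally:
        _DP_SIZE = old_dp_size   # restored even if the body raises

with disable_dp_size():
    print(_DP_SIZE)  # 1
print(_DP_SIZE)      # 4
```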
@@ -159,7 +184,8 @@ def dp_gather(
         layer_id != "embedding" or get_attention_tp_rank() == 0
     ):
         assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
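Note: the `else` branch is a real fix, not a style change. `global_tokens = tensor_model_parallel_all_reduce(global_tokens)` only rebinds the local name, so the buffer the caller passed in never receives the reduced values; `global_tokens[:] = ...` copies the result into that buffer. A toy illustration of the difference, with no distributed setup:

```python
import torch

def reduce_rebind(buf: torch.Tensor) -> None:
    buf = buf * 2     # rebinds the local name; the caller's tensor is untouched

def reduce_inplace(buf: torch.Tensor) -> None:
    buf[:] = buf * 2  # writes into the storage the caller still holds

a = torch.ones(4)
reduce_rebind(a)
print(a)              # tensor([1., 1., 1., 1.]) -- result was lost
reduce_inplace(a)
print(a)              # tensor([2., 2., 2., 2.]) -- caller sees the result
```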


 def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
...
@@ -23,6 +23,7 @@ import triton.language as tl
 from torch import nn

 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
...
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
 class DataParallelController:
     """A controller that dispatches requests to multiple data parallel workers."""

-    def __init__(self, server_args, port_args) -> None:
+    def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
         # Parse args
         self.max_total_num_tokens = None
         self.server_args = server_args
...
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

         return ret
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )
+
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0

-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and global_num_tokens.max().item() > 0:
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
...
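Note: the scheduler change above replaces two collectives (an all-gather of token counts plus an all-reduce for the cuda-graph flag) with a single `all_gather_into_tensor` over a packed int64 tensor. A minimal sketch of the pattern, assuming a single-process gloo group purely for demonstration (`gather_batch_info` is a hypothetical helper, not sglang API):

```python
import torch
import torch.distributed as dist

def gather_batch_info(num_tokens: int, can_cuda_graph: bool, group=None):
    world = dist.get_world_size(group)
    local_info = torch.tensor([num_tokens, int(can_cuda_graph)], dtype=torch.int64)
    global_info = torch.empty((world, 2), dtype=torch.int64)
    # One collective carries every field; slicing columns recovers each one.
    # flatten() on a contiguous tensor is a view, so the gather fills global_info.
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    global_num_tokens = global_info[:, 0].tolist()
    # min() over 0/1 flags acts as a logical AND across ranks.
    can_cuda_graph = bool(min(global_info[:, 1].tolist()))
    return global_num_tokens, can_cuda_graph

if __name__ == "__main__":
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1
    )
    print(gather_batch_info(128, True))  # ([128], True) on a 1-rank group
```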
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip

 _is_hip = is_hip()
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -236,7 +237,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
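Note: the gathered buffer grows by a factor of `num_tokens_per_bs` because, with speculative decoding, each batch slot can carry several draft tokens rather than exactly one. Hypothetical sizes, purely to illustrate the arithmetic:

```python
max_bs, dp_size, num_tokens_per_bs = 160, 2, 4
rows_old = max_bs * dp_size                      # 320 rows: one token per slot
rows_new = max_bs * dp_size * num_tokens_per_bs  # 1280 rows: room for draft tokens
print(rows_old, rows_new)
```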
@@ -276,13 +277,12 @@ class CudaGraphRunner:
     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
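Note: `can_run` now keys captured graphs on the total token count summed across DP ranks instead of requiring every rank to hold the same count. A hypothetical standalone version of the new predicate:

```python
def can_run_dp_graph(global_num_tokens, captured_sizes, max_bs, disable_padding):
    total = sum(global_num_tokens)
    if disable_padding:
        return total in captured_sizes  # need an exact captured size
    return total <= max_bs              # otherwise padding covers the gap

print(can_run_dp_graph([3, 5], {8, 16}, 16, True))  # True: 3 + 5 == 8
```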
@@ -304,6 +304,9 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
                 tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@ class CudaGraphRunner:
                 else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -345,8 +358,18 @@ class CudaGraphRunner:
             mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.tp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
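Note: capture fakes the per-rank token counts with the usual balanced-split idiom: each rank gets `n // k` tokens plus one extra while the remainder lasts, so the counts always sum back to `n`. A minimal sketch of the idiom (a hypothetical helper dividing a single count):

```python
def split_tokens(n: int, k: int) -> list[int]:
    return [n // k + (i < n % k) for i in range(k)]

assert split_tokens(10, 4) == [3, 3, 2, 2]  # sums back to 10
```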
@@ -371,7 +394,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_cpu=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@ class CudaGraphRunner:
         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states
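Note: `run_once` clears the DP-local fields because they are computed lazily and cached on the batch; a value cached *before* capture would be baked into the graph as a stale constant instead of being recomputed from the graph's inputs. A simplified, hypothetical illustration of the hazard:

```python
class Batch:
    def __init__(self):
        self.dp_local_start_pos = None  # lazily derived, then cached

    def get_start_pos(self, global_num_tokens, rank):
        if self.dp_local_start_pos is None:
            self.dp_local_start_pos = sum(global_num_tokens[:rank])
        return self.dp_local_start_pos

b = Batch()
b.get_start_pos([4, 6], 1)         # caches 4
b.dp_local_start_pos = None        # reset, as run_once does
print(b.get_start_pos([8, 2], 1))  # recomputes: 8, not the stale 4
```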
@@ -426,7 +452,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +485,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@ class CudaGraphRunner:
             seq_lens_cpu=self.seq_lens_cpu,
         )

+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]

         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )

         return logits_output
...
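Note: splitting `replay` into `replay_prepare` plus a thin `replay` lets speculative decoding replay a captured graph several times per step while refreshing only `input_ids` and `positions`, skipping the rest of the setup. A simplified, runnable skeleton with toy stand-in classes (not the real API):

```python
class ToyGraph:
    def replay(self):
        print("graph replayed")

class ToyRunner:
    def __init__(self):
        self.graphs = {8: ToyGraph()}

    def replay_prepare(self, batch):
        self.bs = 8                      # stand-in for padding + bisect lookup
        self.raw_num_token = len(batch)  # remembered for slicing outputs later

    def replay(self, batch, skip_attn_backend_init=False):
        if not skip_attn_backend_init:
            self.replay_prepare(batch)   # first replay: full preparation
        # else: only the cheap input-buffer refresh would happen here
        self.graphs[self.bs].replay()

runner = ToyRunner()
runner.replay([1, 2, 3])                               # prepares, then replays
runner.replay([4, 5, 6], skip_attn_backend_init=True)  # reuses stored fields
```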
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend, next_power_of_2
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-            max_len = max(ret.global_num_tokens_cpu)
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
+
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
...
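Note: sizing `gathered_buffer` by the exact token sum matches the new packed layout, where ranks may contribute unequal token counts, and it can also shrink the allocation relative to the old `max_len * tp_size`. A hypothetical comparison:

```python
global_num_tokens = [7, 3, 5, 1]       # per-rank counts, tp_size = 4
rows_old = max(global_num_tokens) * 4  # 28 rows, mostly padding
rows_new = sum(global_num_tokens)      # 16 rows, exactly what is needed
print(rows_old, rows_new)
```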
[diff for one file collapsed in the original view]
@@ -262,14 +262,14 @@ class ServerArgs:
         # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
             )

         # Speculative Decoding
...
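Note: `dp_size` is no longer forced to `tp_size` behind the scenes; with `--enable-dp-attention`, callers must now pass an explicit `--dp` greater than 1, with `dp-size <= tp-size` and `tp_size` divisible by `dp_size`. Mirroring the updated test arguments below:

```python
server_args = [
    "--tp", "2",              # tensor parallel degree
    "--enable-dp-attention",  # enable DP attention
    "--dp", "2",              # now required: 1 < dp-size <= tp-size
]
```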
@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
                 "--tp",
                 "2",
                 "--enable-dp-attention",
+                "--dp",
+                "2",
             ],
         )
...