"vscode:/vscode.git/clone" did not exist on "79db3eb6ca669217e85f5e1cd5e584c6f375681a"
Unverified commit 8e66fbec authored by Lianmin Zheng, committed by GitHub

Improve DP attention (#4390)


Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
parent f141298a
from __future__ import annotations
import functools
import logging
from contextlib import contextmanager
from typing import TYPE_CHECKING, Union
import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -86,6 +90,27 @@ def get_attention_dp_size():
return _DP_SIZE
@contextmanager
def disable_dp_size():
"""Patch the tp group temporarily until this function ends.
This method is for draft workers of speculative decoding to run draft model
with different tp degree from that of target model workers.
Args:
tp_group (GroupCoordinator): the tp group coordinator
"""
global _DP_SIZE
assert _DP_SIZE is not None, "dp attention not initialized!"
old_dp_size = _DP_SIZE
_DP_SIZE = 1
try:
yield
finally:
_DP_SIZE = old_dp_size
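A minimal usage sketch for the context manager above (the module path and the draft-worker call are assumptions, not part of this commit): a speculative-decoding draft worker wraps its forward pass so that DP-attention helpers report a DP size of 1 while it runs, and the previous value is restored on exit even if the forward raises.

from sglang.srt.layers.dp_attention import (  # assumed module path
    disable_dp_size,
    get_attention_dp_size,
)

def run_draft_forward(draft_model, forward_batch):
    # Hypothetical draft-worker helper; assumes DP attention was initialized.
    with disable_dp_size():
        assert get_attention_dp_size() == 1  # patched to 1 inside the block
        out = draft_model.forward(forward_batch)
    # Outside the block the original _DP_SIZE is back in effect.
    return out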
def get_dp_local_info(forward_batch: ForwardBatch):
dp_rank = get_attention_dp_rank()
@@ -159,7 +184,8 @@ def dp_gather(
layer_id != "embedding" or get_attention_tp_rank() == 0
):
assert (
global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
global_tokens.untyped_storage().data_ptr()
!= local_tokens.untyped_storage().data_ptr()
), "aliasing between global_tokens and local_tokens not allowed"
memcpy_triton(
global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
torch.ops.sglang.inplace_all_reduce(
global_tokens, group_name=get_tp_group().unique_name
)
else:
global_tokens = tensor_model_parallel_all_reduce(global_tokens)
global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
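Why the switch to an in-place write matters, shown as a standalone sketch (plain torch, generic names): rebinding the local name leaves the caller's pre-allocated buffer untouched, while slice assignment writes through to the storage the caller still references.

import torch

def reduce_rebind(buf):
    buf = buf * 2     # rebinds the local name only; the caller's tensor is unchanged
    return buf

def reduce_inplace(buf):
    buf[:] = buf * 2  # writes into the caller's storage, like global_tokens[:] = ... above
    return buf

x = torch.ones(4)
reduce_rebind(x)
print(x)              # tensor([1., 1., 1., 1.])
reduce_inplace(x)
print(x)              # tensor([2., 2., 2., 2.])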
def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
# local_num_tokens is not necessarily the same as local_tokens.shape[0],
# since local_tokens may be padded for cuda graph
local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
local_tokens.fill_(0)
assert local_tokens.is_contiguous()
assert global_tokens.is_contiguous()
@@ -23,6 +23,7 @@ import triton.language as tl
from torch import nn
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
)
@@ -54,7 +54,7 @@ class LoadBalanceMethod(Enum):
class DataParallelController:
"""A controller that dispatches requests to multiple data parallel workers."""
def __init__(self, server_args, port_args) -> None:
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
# Parse args
self.max_total_num_tokens = None
self.server_args = server_args
@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Handle DP attention
if self.server_args.enable_dp_attention:
ret = self.prepare_dp_attn_batch(ret)
ret, _ = self.prepare_dp_attn_batch(ret)
return ret
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
global_num_tokens_for_logprob = sum(
[
# We should have at least 1 token to sample in every case.
max(extend_len - logprob_start_len, 1)
for logprob_start_len, extend_len in zip(
local_batch.extend_logprob_start_lens, local_batch.extend_lens
)
]
)
if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
can_cuda_graph = 1
else:
can_cuda_graph = 0
if not self.spec_algorithm.is_none():
# TODO(sang): Support cuda graph when an idle batch is present.
if local_batch is None or local_batch.forward_mode.is_idle():
can_cuda_graph = 0
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
is_extend_in_batch = (
local_batch.forward_mode.is_extend() if local_batch else False
)
local_info = torch.tensor(
[
num_tokens,
can_cuda_graph,
global_num_tokens_for_logprob,
is_extend_in_batch,
],
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 4),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
is_extend_in_batch = global_info[:, 0, 3].tolist()
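A shape-only sketch of the metadata gather above (synthetic values, no process group needed): every rank contributes a 4-element vector, the gather fills a (dp_size, attn_tp_size, 4) tensor, and only the first attention-TP rank of each DP group is read back.

import torch

dp_size, attn_tp_size = 2, 2
local_infos = [
    torch.tensor([16, 0, 16, 1], dtype=torch.int64),  # dp0 / attn-tp rank 0 (extend)
    torch.tensor([16, 0, 16, 1], dtype=torch.int64),  # dp0 / attn-tp rank 1 (extend)
    torch.tensor([0, 1, 0, 0], dtype=torch.int64),    # dp1 / attn-tp rank 0 (idle)
    torch.tensor([0, 1, 0, 0], dtype=torch.int64),    # dp1 / attn-tp rank 1 (idle)
]
global_info = torch.stack(local_infos).view(dp_size, attn_tp_size, 4)

global_num_tokens = global_info[:, 0, 0].tolist()    # [16, 0]
can_cuda_graph = min(global_info[:, 0, 1].tolist())  # 0: the extend batch blocks replay
is_extend_in_batch = global_info[:, 0, 3].tolist()   # [1, 0]
print(global_num_tokens, can_cuda_graph, any(is_extend_in_batch))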
if local_batch is None and global_num_tokens.max().item() > 0:
if local_batch is None and max(global_num_tokens) > 0:
local_batch = self.get_idle_batch()
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(1 if local_batch.forward_mode.is_decode_or_idle() else 0),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
local_batch.can_run_dp_cuda_graph = can_cuda_graph
return local_batch
return local_batch, any(is_extend_in_batch)
def get_idle_batch(self):
idle_batch = ScheduleBatch.init_new(
@@ -33,7 +33,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_available_gpu_memory, is_hip
_is_hip = is_hip()
@@ -174,6 +174,7 @@ class CudaGraphRunner:
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.enable_dp_attention = model_runner.server_args.enable_dp_attention
self.speculative_algorithm = model_runner.server_args.speculative_algorithm
self.tp_size = model_runner.server_args.tp_size
self.dp_size = model_runner.server_args.dp_size
@@ -236,7 +237,7 @@ class CudaGraphRunner:
if self.enable_dp_attention:
self.gathered_buffer = torch.zeros(
(
self.max_bs * self.dp_size,
self.max_bs * self.dp_size * self.num_tokens_per_bs,
self.model_runner.model_config.hidden_size,
),
dtype=self.model_runner.dtype,
@@ -276,13 +277,12 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.enable_dp_attention:
min_num_tokens, max_num_tokens = min(
forward_batch.global_num_tokens_cpu
), max(forward_batch.global_num_tokens_cpu)
total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
(min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
total_global_tokens in self.graphs
if self.disable_padding
else max_num_tokens <= self.max_bs
else total_global_tokens <= self.max_bs
)
else:
is_bs_supported = (
@@ -304,6 +304,9 @@
def capture(self):
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
avail_mem = get_available_gpu_memory(
self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
)
# Reverse the order to enable better memory sharing across cuda graphs.
capture_range = (
tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +314,16 @@
else reversed(self.capture_bs)
)
for bs in capture_range:
if get_tensor_model_parallel_rank() == 0:
avail_mem = get_available_gpu_memory(
self.model_runner.device,
self.model_runner.gpu_id,
empty_cache=False,
)
capture_range.set_description(
f"Capturing batches ({avail_mem=:.2f} GB)"
)
with patch_model(
self.model_runner.model,
bs in self.compile_bs,
@@ -345,8 +358,18 @@
mrope_positions = self.mrope_positions[:, :bs]
if self.enable_dp_attention:
global_num_tokens = [bs] * self.tp_size
gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
self.global_num_tokens_gpu.copy_(
torch.tensor(
[
num_tokens // self.dp_size + (i < bs % self.dp_size)
for i in range(self.dp_size)
],
dtype=torch.int32,
device=input_ids.device,
)
)
global_num_tokens = self.global_num_tokens_gpu
gathered_buffer = self.gathered_buffer[:num_tokens]
else:
global_num_tokens = None
gathered_buffer = None
@@ -371,7 +394,7 @@
encoder_lens=encoder_lens,
return_logprob=False,
positions=positions,
global_num_tokens_cpu=global_num_tokens,
global_num_tokens_gpu=global_num_tokens,
gathered_buffer=gathered_buffer,
mrope_positions=mrope_positions,
spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +415,9 @@
# Run and capture
def run_once():
# Clean intermediate result cache for DP attention
forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
return logits_output.next_token_logits, logits_output.hidden_states
@@ -426,7 +452,7 @@
self.capture_hidden_mode = hidden_mode_from_spec_info
self.capture()
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
def replay_prepare(self, forward_batch: ForwardBatch):
self.recapture_if_needed(forward_batch)
raw_bs = forward_batch.batch_size
@@ -435,7 +461,7 @@
# Pad
if self.enable_dp_attention:
index = bisect.bisect_left(
self.capture_bs, max(forward_batch.global_num_tokens_cpu)
self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
)
else:
index = bisect.bisect_left(self.capture_bs, raw_bs)
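A padding-selection sketch for the DP-attention branch above (synthetic sizes): bisect_left returns the index of the smallest captured batch size that can hold the summed DP token count, which is the graph that gets replayed after padding.

import bisect

capture_bs = [1, 2, 4, 8, 16, 32]   # stand-in for the captured CUDA-graph sizes
global_num_tokens_cpu = [5, 6]      # per-DP-rank token counts
index = bisect.bisect_left(capture_bs, sum(global_num_tokens_cpu))
print(capture_bs[index])            # 16: the graph replayed after padding 11 -> 16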
@@ -459,6 +485,8 @@
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
if forward_batch.mrope_positions is not None:
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
if self.enable_dp_attention:
self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
if hasattr(forward_batch.spec_info, "hidden_states"):
self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +503,29 @@
seq_lens_cpu=self.seq_lens_cpu,
)
# Store fields
self.raw_bs = raw_bs
self.raw_num_token = raw_num_token
self.bs = bs
def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
if not skip_attn_backend_init:
self.replay_prepare(forward_batch)
else:
# In speculative decoding, these two fields are still needed.
self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
self.positions[: self.raw_num_token].copy_(forward_batch.positions)
# Replay
self.graphs[bs].replay()
next_token_logits, hidden_states = self.output_buffers[bs]
self.graphs[self.bs].replay()
next_token_logits, hidden_states = self.output_buffers[self.bs]
logits_output = LogitsProcessorOutput(
next_token_logits=next_token_logits[:raw_num_token],
next_token_logits=next_token_logits[: self.raw_num_token],
hidden_states=(
hidden_states[:raw_num_token] if hidden_states is not None else None
hidden_states[: self.raw_num_token]
if hidden_states is not None
else None
),
)
return logits_output
@@ -38,7 +38,7 @@ import triton
import triton.language as tl
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import get_compiler_backend, next_power_of_2
from sglang.srt.utils import get_compiler_backend
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -263,15 +263,24 @@ class ForwardBatch:
extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
)
# For DP attention
if batch.global_num_tokens is not None:
ret.global_num_tokens_cpu = batch.global_num_tokens
max_len = max(ret.global_num_tokens_cpu)
ret.global_num_tokens_gpu = torch.tensor(
batch.global_num_tokens, dtype=torch.int64
).to(device, non_blocking=True)
ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
batch.global_num_tokens_for_logprob, dtype=torch.int64
).to(device, non_blocking=True)
sum_len = sum(batch.global_num_tokens)
ret.gathered_buffer = torch.zeros(
(max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
(sum_len, model_runner.model_config.hidden_size),
dtype=model_runner.dtype,
device=device,
)
if ret.forward_mode.is_idle():
ret.positions = torch.empty((0,), device=device)
return ret
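A buffer-sizing sketch for the change above (synthetic values, not sglang code): the gathered buffer is now allocated with exactly sum(global_num_tokens) rows instead of max(global_num_tokens) * tp_size rows, so uneven DP loads no longer pay for full padding.

global_num_tokens, tp_size, hidden_size = [5, 11, 2], 4, 4096
old_rows = max(global_num_tokens) * tp_size  # 44 rows, mostly padding
new_rows = sum(global_num_tokens)            # 18 rows, exact fit
print((old_rows, hidden_size), (new_rows, hidden_size))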
@@ -262,14 +262,14 @@ class ServerArgs:
# Data parallelism attention
if self.enable_dp_attention:
self.dp_size = self.tp_size
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // 2
self.schedule_conservativeness = self.schedule_conservativeness * 0.3
assert (
self.dp_size > 1
), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
assert self.tp_size % self.dp_size == 0
self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
logger.warning(
f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
"Data parallel size is adjusted to be the same as tensor parallel size. "
)
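A quick arithmetic note on the adjustment above (the example values are assumptions): with a chunked prefill size of 8192 and dp_size 2, each DP rank now prefills 8192 // 2 = 4096 tokens per chunk, presumably keeping the token count after the DP gather close to the original chunk size.

chunked_prefill_size, dp_size = 8192, 2
print(chunked_prefill_size // dp_size)  # 4096 tokens per DP rank per chunk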
# Speculative Decoding
@@ -25,6 +25,8 @@ class TestDPAttention(unittest.TestCase):
"--tp",
"2",
"--enable-dp-attention",
"--dp",
"2",
],
)