"docs/vscode:/vscode.git/clone" did not exist on "a786ca0cf2d2dd790cecd96b19cc478d7e661cb5"
Unverified commit 976bc302, authored by Ke Bao, committed by GitHub

Support DP MLA (#1970)

parent 2f2e0743
@@ -244,6 +244,7 @@ jobs:
           cd test/srt
           python3 test_mla.py
           python3 test_mla_fp8.py
+          python3 test_dp_attention.py
 
     - name: Evaluate data parallelism accuracy (DP=2)
       timeout-minutes: 10
...
@@ -28,9 +28,13 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+        if model_runner.server_args.enable_dp_attention:
+            self.num_head = model_runner.model_config.num_attention_heads
+        else:
+            self.num_head = (
+                model_runner.model_config.num_attention_heads // model_runner.tp_size
+            )
 
         if global_server_args_dict.get("triton_attention_reduce_in_fp32", False):
             self.reduce_dtype = torch.float32
...
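For clarity, a minimal self-contained illustration (hypothetical numbers, not part of this commit) of what the head-count change above means: plain tensor parallelism shards the attention heads across the TP group, while DP attention keeps every head on each rank because each rank only attends over its own requests.

```python
# Hypothetical configuration: 128 attention heads, tp_size = 8.
num_attention_heads = 128
tp_size = 8

def local_heads(enable_dp_attention: bool) -> int:
    # Mirrors the TritonAttnBackend change: DP attention -> all heads per rank,
    # TP attention -> heads sharded across the TP group.
    if enable_dp_attention:
        return num_attention_heads
    return num_attention_heads // tp_size

print(local_heads(False))  # 16 heads per rank under tensor-parallel attention
print(local_heads(True))   # 128 heads per rank under DP attention
```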
@@ -81,20 +81,34 @@ class DataParallelController:
         # Start data parallel workers
         base_gpu_id = 0
         self.workers = []
+        scheduler_pipe_readers = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
 
-            send_to = self.launch_tensor_parallel_group(
-                server_args,
-                tmp_port_args,
-                base_gpu_id,
-                dp_rank,
-            )
+            if server_args.enable_dp_attention:
+                # Share workers for DP and TP
+                send_to, reader = self.launch_tensor_parallel_process(
+                    server_args,
+                    tmp_port_args,
+                    base_gpu_id,
+                    dp_rank,
+                )
+                base_gpu_id += 1
+                scheduler_pipe_readers.append(reader)
+            else:
+                send_to = self.launch_tensor_parallel_group(
+                    server_args,
+                    tmp_port_args,
+                    base_gpu_id,
+                    dp_rank,
+                )
+                base_gpu_id += server_args.tp_size
 
             self.workers.append(send_to)
-            base_gpu_id += server_args.tp_size
+
+        for reader in scheduler_pipe_readers:
+            reader.recv()
 
     def launch_tensor_parallel_group(
         self,
@@ -132,6 +146,27 @@ class DataParallelController:
 
         return send_to
 
+    def launch_tensor_parallel_process(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        base_gpu_id: int,
+        dp_rank: int,
+    ):
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = base_gpu_id
+        tp_rank = dp_rank
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+        )
+        proc.start()
+        send_to = get_zmq_socket(
+            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
+        )
+        return send_to, reader
+
     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
...
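A small sketch (illustrative numbers, not part of this commit) of what the shared-worker branch above implies for process and GPU layout: with `--enable-dp-attention`, each GPU hosts a single scheduler process that acts as both TP rank i and DP rank i, so `base_gpu_id` advances by 1 per worker instead of by `tp_size`.

```python
# Illustrative only: first GPU id assigned to each DP worker under both branches,
# assuming tp_size == dp_size == 2 as in the new test at the bottom of this diff.
tp_size = dp_size = 2

def first_gpu_ids(enable_dp_attention: bool):
    base_gpu_id, ids = 0, []
    for dp_rank in range(dp_size):
        ids.append(base_gpu_id)
        # Shared workers: one process (one GPU) per DP rank, tp_rank == dp_rank.
        # Separate TP groups: each DP rank owns a block of tp_size GPUs.
        base_gpu_id += 1 if enable_dp_attention else tp_size
    return ids

print(first_gpu_ids(True))   # [0, 1]  -> 2 scheduler processes in total
print(first_gpu_ids(False))  # [0, 2]  -> 2 TP groups, 2 processes each (4 total)
```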
@@ -56,6 +56,7 @@ global_server_args_dict = {
     "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "disable_nan_detection": ServerArgs.disable_nan_detection,
+    "enable_dp_attention": ServerArgs.enable_dp_attention,
 }
@@ -450,6 +451,9 @@ class ScheduleBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+
     # For processing logprobs
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
@@ -858,6 +862,16 @@ class ScheduleBatch:
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
 
+    def prepare_for_idle(self):
+        self.forward_mode = ForwardMode.IDLE
+        self.input_ids = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
+            self.device, non_blocking=True
+        )
+        self.extend_num_tokens = 0
+
     def prepare_for_decode(self, enable_overlap: bool = False):
         self.forward_mode = ForwardMode.DECODE
@@ -969,17 +983,18 @@ class ScheduleBatch:
         self.has_grammar = self.has_grammar or other.has_grammar
 
     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode():
+        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
             extend_prefix_lens = self.prefix_lens
             extend_logprob_start_lens = self.extend_logprob_start_lens
 
-        if self.has_grammar:
-            self.sampling_info.grammars = [req.grammar for req in self.reqs]
-        else:
-            self.sampling_info.grammars = None
+        if self.sampling_info is not None:
+            if self.has_grammar:
+                self.sampling_info.grammars = [req.grammar for req in self.reqs]
+            else:
+                self.sampling_info.grammars = None
 
         global bid
         bid += 1
@@ -995,6 +1010,7 @@ class ScheduleBatch:
             req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
+            global_num_tokens=self.global_num_tokens,
             extend_num_tokens=self.extend_num_tokens,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
@@ -1051,6 +1067,9 @@ class ModelWorkerBatch:
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]]
+
     # For extend
     extend_num_tokens: Optional[int]
     extend_seq_lens: Optional[List[int]]
...
@@ -110,7 +110,7 @@ class Scheduler:
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
@@ -347,6 +347,10 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+
+            if self.server_args.enable_dp_attention:
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
@@ -361,6 +365,8 @@ class Scheduler:
                     self.update_running_batch()
                     if not self.running_batch:
                         break
+                    if self.server_args.enable_dp_attention:
+                        batch = self.prepare_dp_attn_batch(batch)
                     result = self.run_batch(batch)
                     self.process_batch_result(batch, result)
             else:
@@ -396,8 +402,48 @@ class Scheduler:
 
         self.last_batch = batch
 
+    def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
+        # Check if other DP workers have running batches
+        if local_batch is None:
+            num_tokens = 0
+        elif local_batch.forward_mode.is_decode():
+            num_tokens = local_batch.batch_size()
+        else:
+            num_tokens = local_batch.extend_num_tokens
+
+        local_num_tokens = torch.tensor(
+            num_tokens, dtype=torch.int64, device=self.device
+        )
+        global_num_tokens = torch.empty(
+            self.tp_size, dtype=torch.int64, device=self.device
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_num_tokens,
+            local_num_tokens,
+            group=self.tp_worker.get_tp_device_group(),
+        )
+
+        if local_batch is None and global_num_tokens.max().item() > 0:
+            local_batch = self.get_idle_batch()
+
+        if local_batch is not None:
+            local_batch.global_num_tokens = global_num_tokens.tolist()
+
+        return local_batch
+
+    def get_idle_batch(self):
+        idle_batch = ScheduleBatch.init_new(
+            [],
+            self.req_to_token_pool,
+            self.token_to_kv_pool,
+            self.tree_cache,
+            self.model_config,
+        )
+        idle_batch.prepare_for_idle()
+        return idle_batch
+
     def recv_requests(self):
-        if self.tp_rank == 0:
+        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
             recv_reqs = []
 
             while True:
@@ -409,7 +455,7 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1:
+        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
@@ -812,6 +858,10 @@ class Scheduler:
             logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                 model_worker_batch
            )
+        elif batch.forward_mode.is_idle():
+            model_worker_batch = batch.get_model_worker_batch()
+            self.tp_worker.forward_batch_idle(model_worker_batch)
+            return
         else:
             logits_output = None
             if self.skip_tokenizer_init:
@@ -830,6 +880,8 @@ class Scheduler:
         return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
+        if batch.forward_mode.is_idle():
+            return
         if batch.forward_mode.is_decode():
             self.process_batch_result_decode(batch, result)
             if batch.is_empty():
...
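The scheduling changes above hinge on one constraint: every DP rank must enter the token-count all-gather on every iteration, even when it has no requests, otherwise the collective (and the per-layer all-gather inside the model) would hang; ranks with no work therefore run an IDLE batch. Below is a self-contained sketch of that synchronization idea, using the gloo backend and `all_gather_object` on CPU instead of the CUDA tensor collective used in `prepare_dp_attn_batch` (illustrative only, not part of this commit).

```python
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank: int, world_size: int):
    # Illustrative stand-in for prepare_dp_attn_batch: gloo/CPU instead of NCCL/GPU.
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=world_size
    )
    local_num_tokens = 7 if rank == 0 else 0  # pretend only rank 0 has a batch
    global_num_tokens = [None] * world_size
    dist.all_gather_object(global_num_tokens, local_num_tokens)  # every rank must call this

    if local_num_tokens == 0 and max(global_num_tokens) > 0:
        print(f"rank {rank}: no local work, running an IDLE forward to stay in sync")
    else:
        print(f"rank {rank}: running a real batch, global_num_tokens={global_num_tokens}")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)
```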
@@ -128,12 +128,19 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
+    def get_tp_device_group(self):
+        return self.model_runner.tp_group.device_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )
 
+    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        self.model_runner.forward(forward_batch)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
...
@@ -88,6 +88,9 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
+    def get_tp_device_group(self):
+        return self.worker.get_tp_device_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
...
@@ -56,6 +56,8 @@ class ForwardMode(IntEnum):
     DECODE = auto()
     # Contains both EXTEND and DECODE.
     MIXED = auto()
+    # No sequence to forward. For data parallel attention, some workers will be IDLE if no sequence is allocated.
+    IDLE = auto()
 
     def is_prefill(self):
         return self == ForwardMode.PREFILL
@@ -69,6 +71,9 @@ class ForwardMode(IntEnum):
     def is_mixed(self):
         return self == ForwardMode.MIXED
 
+    def is_idle(self):
+        return self == ForwardMode.IDLE
+
 
 @dataclass
 class ForwardBatch:
@@ -128,6 +133,10 @@ class ForwardBatch:
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
+    # For DP attention
+    global_num_tokens: Optional[List[int]] = None
+    gathered_buffer: Optional[torch.Tensor] = None
+
     def compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
@@ -209,10 +218,22 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
+            global_num_tokens=batch.global_num_tokens,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
        )
 
+        if ret.global_num_tokens is not None:
+            max_len = max(ret.global_num_tokens)
+            ret.gathered_buffer = torch.zeros(
+                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                dtype=model_runner.dtype,
+                device=device,
+            )
+
+        if ret.forward_mode.is_idle():
+            return ret
+
         # Init position information
         if not ret.forward_mode.is_decode():
             ret.positions = torch.concat(
...
@@ -141,6 +141,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "disable_penalizer": server_args.disable_penalizer,
                 "disable_nan_detection": server_args.disable_nan_detection,
+                "enable_dp_attention": server_args.enable_dp_attention,
             }
         )
@@ -592,11 +593,18 @@ class ModelRunner:
             get_embedding=True,
         )
 
+    def forward_idle(self, forward_batch: ForwardBatch):
+        return self.model.forward(
+            forward_batch.input_ids, forward_batch.positions, forward_batch
+        )
+
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(forward_batch)
+        elif forward_batch.forward_mode.is_idle():
+            return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
...
@@ -22,7 +22,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
-        if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
-                self.hidden_size,
-                self.q_lora_rank,
-                bias=False,
-                quant_config=quant_config,
-            )
-            self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
-            self.q_b_proj = ColumnParallelLinear(
-                q_lora_rank,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-            )
-        else:
-            self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
-                self.num_heads * self.qk_head_dim,
-                bias=False,
-                quant_config=quant_config,
-            )
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+            )
+        else:
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
+                bias=False,
+                quant_config=quant_config,
+            )
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
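To make the new `all_gather` helper concrete, here is a single-process walkthrough of its pad → gather → strip-padding → slice pattern with made-up sizes (two ranks holding 3 and 1 tokens). The real code performs the gather across the TP group with `torch.distributed.all_gather_into_tensor` into `forward_batch.gathered_buffer`; the concatenation below just simulates that collective.

```python
import torch
import torch.nn.functional as F

# Illustrative: 2 "ranks", rank 0 holds 3 tokens, rank 1 holds 1 token, hidden size 4.
world_size, hidden = 2, 4
all_lens = [3, 1]                      # plays the role of forward_batch.global_num_tokens
max_len = max(all_lens)

per_rank = [torch.randn(n, hidden) for n in all_lens]

# Each rank pads its local tokens to max_len before the collective ...
padded = [F.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in per_rank]
# ... and the all-gather concatenates the padded chunks (simulated here with cat).
gathered_buffer = torch.cat(padded)    # shape: (max_len * world_size, hidden)

# Strip the padding so the MLP sees exactly sum(all_lens) real tokens.
gathered = torch.cat(
    [gathered_buffer[i * max_len : i * max_len + all_lens[i]] for i in range(world_size)]
)
assert gathered.shape[0] == sum(all_lens)

# After the tensor-parallel MLP, each rank slices its own rows back out.
rank = 1
start = sum(all_lens[:rank])
end = start + all_lens[rank]
assert torch.allclose(gathered[start:end], per_rank[rank])
```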
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-        if residual is None:
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
-        else:
-            hidden_states, residual = self.input_layernorm(hidden_states, residual)
-        hidden_states = self.self_attn(
-            positions=positions,
-            hidden_states=hidden_states,
-            forward_batch=forward_batch,
-        )
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         return hidden_states, residual
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-        hidden_states, _ = self.norm(hidden_states, residual)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
-        self.logits_processor = LogitsProcessor(config)
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
-        )
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -129,6 +129,7 @@ class ServerArgs:
     disable_nan_detection: bool = False
     enable_overlap_schedule: bool = False
     enable_mixed_chunk: bool = False
+    enable_dp_attention: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: int = 160
@@ -203,6 +204,16 @@ class ServerArgs:
         if self.sampling_backend is None:
             self.sampling_backend = "flashinfer"
 
+        if self.enable_dp_attention:
+            self.dp_size = self.tp_size
+            self.chunked_prefill_size = self.chunked_prefill_size // 2
+            self.disable_cuda_graph = True
+            self.enable_overlap_schedule = False
+            logger.warning(
+                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issues. "
+                "The CUDA graph is disabled."
+            )
+
         if self.enable_overlap_schedule:
             logger.warning(
                 "Overlap scheduler mode is enabled. This is an experimental feature. "
@@ -669,6 +680,11 @@ class ServerArgs:
             action="store_true",
             help="Enabling mixing prefill and decode in a batch when using chunked prefill.",
         )
+        parser.add_argument(
+            "--enable-dp-attention",
+            action="store_true",
+            help="Enabling data parallelism for attention and tensor parallelism for FFN. The dp size should be equal to the tp size. Currently only DeepSeek-V2 is supported.",
+        )
         parser.add_argument(
             "--enable-torch-compile",
             action="store_true",
...
New file: test/srt/test_dp_attention.py

import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MLA_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestDPAttention(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--trust-remote-code",
                "--tp",
                "2",
                "--dp",
                "2",
                "--enable-dp-attention",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_child_process(cls.process.pid, include_self=True)

    def test_mmlu(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mmlu",
            num_examples=64,
            num_threads=32,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.5

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.8


if __name__ == "__main__":
    unittest.main()