Unverified commit f44d1439, authored by Lianmin Zheng, committed by GitHub

Support target model verification in the attention backend (#2678)

Co-authored-by: yukavio <kavioyu@gmail.com>
parent b6b57fc2
+from __future__ import annotations
+
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+    from sglang.srt.speculative.spec_info import SpecInfo


 class AttentionBackend(ABC):
@@ -22,9 +26,12 @@ class AttentionBackend(ABC):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -35,7 +42,9 @@ class AttentionBackend(ABC):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         """Init the metadata for a forward pass for replying a cuda graph."""
         raise NotImplementedError()
...
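For orientation, the extended hooks now look like this to an implementer. The sketch below is illustrative only (the MyBackend class and its method bodies are hypothetical); the signatures and the decode / target-verify dispatch mirror the interface above and the FlashInfer implementation later in this commit.

from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.speculative.spec_info import SpecInfo


class MyBackend(AttentionBackend):
    # Hypothetical backend showing the extended CUDA-graph capture hook.
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_token: int,  # new: number of tokens in the graph (bs * tokens per request)
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,  # new: DECODE or TARGET_VERIFY can be captured
        spec_info: Optional[SpecInfo],  # new: draft-token layout for speculative runs
    ):
        if forward_mode.is_decode():
            ...  # build decode metadata sized by num_token
        elif forward_mode.is_target_verify():
            ...  # build verify (prefill-style) metadata driven by spec_info
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")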
@@ -3,7 +3,6 @@ from __future__ import annotations
 from typing import TYPE_CHECKING

 import torch
-import torch.nn as nn

 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend):

         self.forward_metadata = None

-        self.cuda_graph_max_seq_len = model_runner.model_config.context_len
-
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
@@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend):
             ds_req_to_token,
         )

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO(Andy): Support CUDA graph for double sparse attention
-        raise ValueError(
-            "Double sparse attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-        self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.cuda_graph_attn_logits = torch.empty(
-            (
-                self.num_head,
-                self.cuda_graph_max_total_num_tokens,
-            ),
-            dtype=self.reduce_dtype,
-            device="cuda",
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.forward_metadata = (
-            self.cuda_graph_start_loc,
-            self.cuda_graph_attn_logits,
-            self.cuda_graph_max_seq_len,
-            None,
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens=None,
-    ):
-        # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        return 1
-
     def forward_extend(
         self,
         q,
...
@@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an
 import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import TYPE_CHECKING, List, Union
+from typing import TYPE_CHECKING, List, Optional, Union

 import torch
 import triton
@@ -18,12 +18,13 @@ import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo

 if is_flashinfer_available():
     from flashinfer import (
@@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend):
        # Two wrappers: one for sliding window attention and one for full attention.
        # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs
        self.prefill_wrappers_paged = []
+       self.prefill_wrappers_verify = []
        self.decode_wrappers = []
        for _ in range(self.num_wrappers):
            self.prefill_wrappers_paged.append(
                BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
            )
+           self.prefill_wrappers_verify.append(
+               BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
+           )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
@@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend):
        # Other metadata
        self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None
        self.decode_cuda_graph_metadata = {}
+       self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        if forward_batch.forward_mode.is_decode():
@@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend):
                forward_batch.seq_lens_sum,
                decode_wrappers=self.decode_wrappers,
                encoder_lens=forward_batch.encoder_lens,
+               spec_info=forward_batch.spec_info,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
+       elif forward_batch.forward_mode.is_draft_extend():
+           self.indices_updater_prefill.update(
+               forward_batch.req_pool_indices,
+               forward_batch.seq_lens,
+               forward_batch.seq_lens_sum,
+               prefix_lens=None,
+               prefill_wrappers=self.prefill_wrappers_paged,
+               use_ragged=False,
+               encoder_lens=forward_batch.encoder_lens,
+               spec_info=forward_batch.spec_info,
+           )
+           self.forward_metadata = PrefillMetadata(
+               self.prefill_wrappers_paged, False, False
+           )
+       elif forward_batch.forward_mode.is_target_verify():
+           self.indices_updater_prefill.update(
+               forward_batch.req_pool_indices,
+               forward_batch.seq_lens,
+               forward_batch.seq_lens_sum,
+               prefix_lens=None,
+               prefill_wrappers=self.prefill_wrappers_verify,
+               use_ragged=False,
+               encoder_lens=forward_batch.encoder_lens,
+               spec_info=forward_batch.spec_info,
+           )
+           self.forward_metadata = PrefillMetadata(
+               self.prefill_wrappers_verify, False, False
+           )
        else:
            prefix_lens = forward_batch.extend_prefix_lens
@@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend):
                prefill_wrappers=self.prefill_wrappers_paged,
                use_ragged=use_ragged,
                encoder_lens=forward_batch.encoder_lens,
+               spec_info=None,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrappers_paged, use_ragged, extend_no_prefix
@@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend):
            cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1)
        ]

+       self.cuda_graph_custom_mask = torch.zeros(
+           (max_bs * self.max_context_len),
+           dtype=torch.uint8,
+           device="cuda",
+       )
+       self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr]
+       self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr]
+
    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
+       num_token: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
-       encoder_lens: torch.Tensor = None,
+       encoder_lens: Optional[torch.Tensor],
+       forward_mode: ForwardMode,
+       spec_info: Optional[SpecInfo],
    ):
-       decode_wrappers = []
-       for i in range(self.num_wrappers):
-           decode_wrappers.append(
-               BatchDecodeWithPagedKVCacheWrapper(
-                   self.workspace_buffer,
-                   "NHD",
-                   use_cuda_graph=True,
-                   use_tensor_cores=self.decode_use_tensor_cores,
-                   paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1],
-                   paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
-                   paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs],
-               )
-           )
-
-       seq_lens_sum = seq_lens.sum().item()
-       self.indices_updater_decode.update(
-           req_pool_indices,
-           seq_lens,
-           seq_lens_sum,
-           decode_wrappers=decode_wrappers,
-           encoder_lens=encoder_lens,
-       )
-       self.decode_cuda_graph_metadata[bs] = decode_wrappers
-       self.forward_metadata = DecodeMetadata(decode_wrappers)
+       if forward_mode.is_decode():
+           decode_wrappers = []
+           for i in range(self.num_wrappers):
+               decode_wrappers.append(
+                   BatchDecodeWithPagedKVCacheWrapper(
+                       self.workspace_buffer,
+                       "NHD",
+                       use_cuda_graph=True,
+                       use_tensor_cores=self.decode_use_tensor_cores,
+                       paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1],
+                       paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
+                       paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token],
+                   )
+               )
+           seq_lens_sum = seq_lens.sum().item()
+           self.indices_updater_decode.update(
+               req_pool_indices,
+               seq_lens,
+               seq_lens_sum,
+               decode_wrappers=decode_wrappers,
+               encoder_lens=encoder_lens,
+               spec_info=spec_info,
+           )
+           self.decode_cuda_graph_metadata[bs] = decode_wrappers
+           self.forward_metadata = DecodeMetadata(decode_wrappers)
+       elif forward_mode.is_target_verify():
+           prefill_wrappers = []
+           for i in range(self.num_wrappers):
+               prefill_wrappers.append(
+                   BatchPrefillWithPagedKVCacheWrapper(
+                       self.workspace_buffer,
+                       "NHD",
+                       use_cuda_graph=True,
+                       qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                       paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                       paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                       paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                       custom_mask_buf=self.cuda_graph_custom_mask,
+                       qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1],
+                   )
+               )
+           seq_lens_sum = seq_lens.sum().item()
+           self.indices_updater_prefill.update(
+               req_pool_indices,
+               seq_lens,
+               seq_lens_sum,
+               prefix_lens=None,
+               prefill_wrappers=prefill_wrappers,
+               use_ragged=False,
+               encoder_lens=encoder_lens,
+               spec_info=spec_info,
+           )
+           self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+           self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+       else:
+           raise ValueError(f"Invalid mode: {forward_mode=}")

    def init_forward_metadata_replay_cuda_graph(
        self,
@@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend):
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
-       encoder_lens: torch.Tensor = None,
+       encoder_lens: Optional[torch.Tensor],
+       forward_mode: ForwardMode,
+       spec_info: Optional[SpecInfo],
    ):
-       self.indices_updater_decode.update(
-           req_pool_indices[:bs],
-           seq_lens[:bs],
-           seq_lens_sum,
-           decode_wrappers=self.decode_cuda_graph_metadata[bs],
-           encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
-       )
+       if forward_mode.is_decode():
+           self.indices_updater_decode.update(
+               req_pool_indices[:bs],
+               seq_lens[:bs],
+               seq_lens_sum,
+               decode_wrappers=self.decode_cuda_graph_metadata[bs],
+               encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+               spec_info=spec_info,
+           )
+       elif forward_mode.is_target_verify():
+           self.indices_updater_prefill.update(
+               req_pool_indices[:bs],
+               seq_lens[:bs],
+               seq_lens_sum,
+               prefix_lens=None,
+               prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+               use_ragged=False,
+               encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+               spec_info=spec_info,
+           )
+       else:
+           raise ValueError("Invalid forward mode")

    def get_cuda_graph_seq_len_fill_value(self):
        return 0

    def forward_extend(
        self,
-       q,
-       k,
-       v,
+       q: torch.Tensor,
+       k: torch.Tensor,
+       v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
@@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend):

    def forward_decode(
        self,
-       q,
-       k,
-       v,
+       q: torch.Tensor,
+       k: torch.Tensor,
+       v: torch.Tensor,
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
@@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode:
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
-
        self.attn_backend = attn_backend

        # Buffers and wrappers
@@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
@@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        decode_wrappers = decode_wrappers or self.decode_wrappers
        self.call_begin_forward(
@@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode:
            seq_lens_sum,
            self.kv_indptr[0],
            None,
+           spec_info,
        )

    def update_sliding_window(
@@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode:
                paged_kernel_lens_sum_tmp,
                self.kv_indptr[wrapper_id],
                kv_start_idx_tmp,
+               spec_info,
            )

    def update_cross_attention(
@@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode:
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode:
                seq_lens_sum,
                self.kv_indptr[wrapper_id],
                kv_start_idx,
+               spec_info,
            )

    def call_begin_forward(
@@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode:
        paged_kernel_lens_sum: int,
        kv_indptr: torch.Tensor,
        kv_start_idx: torch.Tensor,
+       spec_info: Optional[SpecInfo],
    ):
-       bs = len(req_pool_indices)
-       kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-       kv_indptr = kv_indptr[: bs + 1]
-       kv_indices = torch.empty(
-           paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-       )
-       create_flashinfer_kv_indices_triton[(bs,)](
-           self.req_to_token,
-           req_pool_indices,
-           paged_kernel_lens,
-           kv_indptr,
-           kv_start_idx,
-           kv_indices,
-           self.req_to_token.shape[1],
-       )
+       if spec_info is None:
+           bs = len(req_pool_indices)
+           kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+           kv_indptr = kv_indptr[: bs + 1]
+           kv_indices = torch.empty(
+               paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+           )
+           create_flashinfer_kv_indices_triton[(bs,)](
+               self.req_to_token,
+               req_pool_indices,
+               paged_kernel_lens,
+               kv_indptr,
+               kv_start_idx,
+               kv_indices,
+               self.req_to_token.shape[1],
+           )
+       else:
+           bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode(
+               req_pool_indices,
+               paged_kernel_lens,
+               self.req_to_token,
+           )

        wrapper.end_forward()
        wrapper.begin_forward(
@@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill:
        self.data_type = model_runner.kv_cache_dtype
        self.q_data_type = model_runner.dtype
        self.sliding_window_size = model_runner.sliding_window_size
-
        self.attn_backend = attn_backend

        # Buffers and wrappers
@@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        # Keep the signature for type checking. It will be assigned during runtime.
        raise NotImplementedError()
@@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        if use_ragged:
            paged_kernel_lens = prefix_lens
@@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill:
            self.kv_indptr[0],
            self.qo_indptr[0],
            use_ragged,
+           spec_info,
        )

    def update_sliding_window(
@@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
+               spec_info,
            )

    def update_cross_attention(
@@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill:
        prefix_lens: torch.Tensor,
        prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
        use_ragged: bool,
-       encoder_lens: torch.Tensor,
+       encoder_lens: Optional[torch.Tensor],
+       spec_info: Optional[SpecInfo],
    ):
        for wrapper_id in range(2):
            if wrapper_id == 0:
@@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill:
                self.kv_indptr[wrapper_id],
                self.qo_indptr[wrapper_id],
                use_ragged,
+               spec_info,
            )

    def call_begin_forward(
@@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill:
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
+       spec_info: Optional[SpecInfo],
    ):
        bs = len(req_pool_indices)
-       kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
-       kv_indptr = kv_indptr[: bs + 1]
-       kv_indices = torch.empty(
-           paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
-       )
-       create_flashinfer_kv_indices_triton[(bs,)](
-           self.req_to_token,
-           req_pool_indices,
-           paged_kernel_lens,
-           kv_indptr,
-           kv_start_idx,
-           kv_indices,
-           self.req_to_token.shape[1],
-       )
-
-       qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
-       qo_indptr = qo_indptr[: bs + 1]
+       if spec_info is None:
+           # Normal extend
+           kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+           kv_indptr = kv_indptr[: bs + 1]
+           kv_indices = torch.empty(
+               paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
+           )
+           create_flashinfer_kv_indices_triton[(bs,)](
+               self.req_to_token,
+               req_pool_indices,
+               paged_kernel_lens,
+               kv_indptr,
+               kv_start_idx,
+               kv_indices,
+               self.req_to_token.shape[1],
+           )
+           qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0)
+           qo_indptr = qo_indptr[: bs + 1]
+           custom_mask = None
+       else:
+           kv_indices, kv_indptr, qo_indptr, custom_mask = (
+               spec_info.generate_attn_arg_prefill(
+                   req_pool_indices,
+                   paged_kernel_lens,
+                   self.req_to_token,
+               )
+           )

        # extend part
        if use_ragged:
@@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill:
            self.head_dim,
            1,
            q_data_type=self.q_data_type,
+           custom_mask=custom_mask,
        )
...
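The verify path leans on flashinfer's custom-mask support: the capture path above allocates cuda_graph_custom_mask as a flat uint8 buffer of size max_bs * max_context_len, and call_begin_forward passes whatever mask spec_info.generate_attn_arg_prefill returns straight into begin_forward(..., custom_mask=custom_mask). The exact mask layout is defined by SpecInfo, which is outside this diff; as a rough sketch only, a flattened causal mask for verifying a chain of draft tokens could be built as below (the function name and layout are assumptions, not part of the commit).

import torch


def build_chain_verify_mask(seq_lens: torch.Tensor, num_draft_tokens: int) -> torch.Tensor:
    """Illustrative only: flattened uint8 mask where draft-query token r of a request
    attends to every committed KV entry plus draft tokens 0..r (causal suffix).

    seq_lens is assumed to already include the appended draft tokens.
    """
    masks = []
    for kv_len in seq_lens.tolist():
        q = torch.arange(num_draft_tokens).unsqueeze(1)    # (k, 1) query offsets
        c = torch.arange(kv_len).unsqueeze(0)              # (1, kv_len) kv positions
        allowed = c <= (kv_len - num_draft_tokens + q)     # causal over the draft suffix
        masks.append(allowed.flatten())
    return torch.cat(masks).to(torch.uint8)                # length: sum(k * kv_len)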
 from __future__ import annotations

-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING

 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass

-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
...
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional

 import torch

 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo


 class TritonAttnBackend(AttentionBackend):
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_token: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):

     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):

     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
...
@@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp

-from sglang.srt.layers.logits_processor import (
-    LogitsMetadata,
-    LogitsProcessor,
-    LogitsProcessorOutput,
-)
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
 from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
@@ -153,6 +153,10 @@ class CudaGraphRunner:
            if bs <= model_runner.req_to_token_pool.size
            and bs <= model_runner.server_args.cuda_graph_max_bs
        ]

+       self.capture_forward_mode = ForwardMode.DECODE
+       self.num_tokens_per_bs = 1
+
        self.compile_bs = (
            [
                bs
@@ -165,8 +169,8 @@ class CudaGraphRunner:
        # Attention backend
        self.max_bs = max(self.capture_bs)
-       self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
+       self.max_num_token = self.max_bs * self.num_tokens_per_bs
+       self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token)
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )
@@ -179,12 +183,13 @@ class CudaGraphRunner:
        # Common inputs
        with torch.device("cuda"):
-           self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+           self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
-           self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+           self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+           self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)

            if self.is_encoder_decoder:
@@ -229,6 +234,9 @@ class CudaGraphRunner:
        self.model_runner.model.capture_mode = False

    def can_run(self, forward_batch: ForwardBatch):
+       if not forward_batch.forward_mode.is_cuda_graph():
+           return False
+
        if self.enable_dp_attention:
            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                forward_batch.global_num_tokens
@@ -258,12 +266,12 @@ class CudaGraphRunner:
    def capture(self):
        with graph_capture() as graph_capture_context:
            self.stream = graph_capture_context.stream
-           capture_bs = (
+           capture_range = (
                tqdm.tqdm(self.capture_bs)
                if get_tensor_model_parallel_rank() == 0
                else self.capture_bs
            )
-           for bs in capture_bs:
+           for bs in capture_range:
                with patch_model(
                    self.model_runner.model,
                    bs in self.compile_bs,
@@ -283,12 +291,15 @@ class CudaGraphRunner:
    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
+       num_token = bs * self.num_tokens_per_bs

        # Common inputs
-       input_ids = self.input_ids[:bs]
+       input_ids = self.input_ids[:num_token]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
-       out_cache_loc = self.out_cache_loc[:bs]
+       out_cache_loc = self.out_cache_loc[:num_token]
+       positions = self.positions[:num_token]

        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
@@ -304,37 +315,41 @@ class CudaGraphRunner:
            global_num_tokens = None
            gathered_buffer = None

+       forward_batch = ForwardBatch(
+           forward_mode=self.capture_forward_mode,
+           batch_size=bs,
+           input_ids=input_ids,
+           req_pool_indices=req_pool_indices,
+           seq_lens=seq_lens,
+           req_to_token_pool=self.model_runner.req_to_token_pool,
+           token_to_kv_pool=self.model_runner.token_to_kv_pool,
+           attn_backend=self.model_runner.attn_backend,
+           out_cache_loc=out_cache_loc,
+           seq_lens_sum=seq_lens_sum,
+           encoder_lens=encoder_lens,
+           return_logprob=False,
+           top_logprobs_nums=[0] * num_token,
+           positions=positions,
+           global_num_tokens=global_num_tokens,
+           mrope_positions=mrope_positions,
+           gathered_buffer=gathered_buffer,
+       )
+
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
+           num_token,
            req_pool_indices,
            seq_lens,
            encoder_lens,
+           forward_batch.forward_mode,
+           forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
-           forward_batch = ForwardBatch(
-               forward_mode=ForwardMode.DECODE,
-               batch_size=bs,
-               input_ids=input_ids,
-               req_pool_indices=req_pool_indices,
-               seq_lens=seq_lens,
-               req_to_token_pool=self.model_runner.req_to_token_pool,
-               token_to_kv_pool=self.model_runner.token_to_kv_pool,
-               attn_backend=self.model_runner.attn_backend,
-               out_cache_loc=out_cache_loc,
-               seq_lens_sum=seq_lens_sum,
-               encoder_lens=encoder_lens,
-               return_logprob=False,
-               top_logprobs_nums=[0] * bs,
-               positions=clamp_position(seq_lens),
-               mrope_positions=mrope_positions,
-               global_num_tokens=global_num_tokens,
-               gathered_buffer=gathered_buffer,
-           )
            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-           return logits_output.next_token_logits
+           return logits_output.next_token_logits, logits_output.hidden_states

        for _ in range(2):
            torch.cuda.synchronize()
@@ -360,6 +375,9 @@ class CudaGraphRunner:
    def replay(self, forward_batch: ForwardBatch):
        assert forward_batch.out_cache_loc is not None
        raw_bs = forward_batch.batch_size
+       # In normal decoding case, raw_bs == raw_num_token
+       # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs
+       raw_num_token = forward_batch.input_ids.numel()

        # Pad
        if self.enable_dp_attention:
@@ -374,10 +392,13 @@ class CudaGraphRunner:
            self.out_cache_loc.zero_()

        # Common inputs
-       self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+       self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
-       self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
+       self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
+       positions = clamp_position(forward_batch.seq_lens)
+       self.positions[:raw_num_token].copy_(positions)

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
@@ -390,13 +411,18 @@ class CudaGraphRunner:
            self.seq_lens,
            forward_batch.seq_lens_sum + (bs - raw_bs),
            self.encoder_lens,
+           forward_batch.forward_mode,
+           forward_batch.spec_info,
        )

        # Replay
        self.graphs[bs].replay()
-       next_token_logits = self.output_buffers[bs][:raw_bs]
+       next_token_logits, hidden_states = self.output_buffers[bs]

        logits_output = LogitsProcessorOutput(
-           next_token_logits=next_token_logits,
+           next_token_logits=next_token_logits[:raw_num_token],
+           hidden_states=(
+               hidden_states[:raw_num_token] if hidden_states is not None else None
+           ),
        )
        return logits_output
...
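To make the replay bookkeeping above concrete: bs is the padded batch size of the captured graph, raw_bs the number of live requests, and raw_num_token (new in this commit) the number of real input tokens, which equals raw_bs only for plain decoding. A small worked example with made-up numbers:

# Made-up numbers, for illustration only.
num_tokens_per_bs = 4                         # tokens per request per step (1 for plain decode)
raw_bs = 3                                    # live requests in the batch
raw_num_token = raw_bs * num_tokens_per_bs    # 12, i.e. forward_batch.input_ids.numel()

bs = 4                                        # smallest captured graph batch >= raw_bs
padded_num_token = bs * num_tokens_per_bs     # 16 slots in the captured input buffers

# replay() copies the first raw_num_token entries of input_ids / out_cache_loc /
# positions and the first raw_bs entries of req_pool_indices / seq_lens, replays
# the graph captured for bs, then slices the outputs back to [:raw_num_token].
print(raw_num_token, padded_num_token)        # 12 16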
@@ -96,7 +96,7 @@ class ForwardMode(IntEnum):
         return self == ForwardMode.DRAFT_EXTEND

     def is_cuda_graph(self):
-        return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY)
+        return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY

     def is_dummy_first(self):
         return self == ForwardMode.DUMMY_FIRST
...