Unverified commit 81a632ac authored by hlu1, committed by GitHub

[DeepseekV32] Enable flashmla_prefill kernel with fp8 kvcache (#11655)


Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
parent 83b22400
......@@ -22,6 +22,10 @@ def _dequantize_k_cache_slow(
De-quantize the k-cache
"""
assert dv % tile_size == 0
original_ndim = quant_k_cache.ndim
if original_ndim == 3:
# set block_size = 1
quant_k_cache = quant_k_cache.unsqueeze(1)
num_tiles = dv // tile_size
num_blocks, block_size, h_k, _ = quant_k_cache.shape
assert h_k == 1
......@@ -45,8 +49,10 @@ def _dequantize_k_cache_slow(
cur_nope * cur_scales
)
result = result.view(num_blocks, block_size, 1, d)
return result
if original_ndim == 3:
return result.view(num_blocks, 1, -1)
else:
return result.view(num_blocks, block_size, 1, -1)
def _dequantize_k_cache_fast_wrapped(
......@@ -54,7 +60,10 @@ def _dequantize_k_cache_fast_wrapped(
dv: int = 512,
tile_size: int = 128,
) -> torch.Tensor:
# TODO: the final API may be 2D instead of 4D, so we convert the shapes here
original_ndim = quant_k_cache.ndim
if original_ndim == 3:
# set block_size = 1
quant_k_cache = quant_k_cache.unsqueeze(1)
num_blocks, block_size, _, dim_quant = quant_k_cache.shape
assert dv == 512
assert dim_quant == 656
......@@ -63,7 +72,10 @@ def _dequantize_k_cache_fast_wrapped(
output = _dequantize_k_cache_fast(quant_k_cache)
return output.view(num_blocks, block_size, 1, -1)
if original_ndim == 3:
return output.view(num_blocks, 1, -1)
else:
return output.view(num_blocks, block_size, 1, -1)
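Both dequant wrappers accept either the paged 4D layout or a flattened 3D layout. A minimal usage sketch, assuming `_quantize_k_cache_fast_wrapped` from the sibling quant_k_cache module and the dim_quant = 656 / dv = 512 values asserted above:

input_k = torch.randn((10, 64, 1, 576), dtype=torch.bfloat16, device="cuda")
quant_4d = _quantize_k_cache_fast_wrapped(input_k)     # [10, 64, 1, 656], fp8 storage
quant_3d = quant_4d.view(-1, 1, 656)                   # block_size folded away: [640, 1, 656]
out_4d = _dequantize_k_cache_fast_wrapped(quant_4d)    # [10, 64, 1, 576], bfloat16
out_3d = _dequantize_k_cache_fast_wrapped(quant_3d)    # [640, 1, 576], bfloat16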
def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
......@@ -85,7 +97,6 @@ def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
NUM_NOPE_BLOCKS = dim_nope // group_size
input_nope_q = quant_k_cache[:, :dim_nope]
input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
......@@ -102,7 +113,7 @@ def _dequantize_k_cache_fast(quant_k_cache, group_size: int = 128):
input_nope_q.stride(0),
input_nope_s.stride(0),
input_rope.stride(0),
NUM_NOPE_BLOCKS=NUM_NOPE_BLOCKS,
NUM_NOPE_BLOCKS=num_tiles,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
......@@ -159,5 +170,126 @@ def _dequantize_k_cache_fast_kernel(
tl.store(dst_ptr, data, mask=mask)
def dequantize_k_cache_paged(
quant_k_cache: torch.Tensor,
page_table_1_flattened: torch.Tensor,
group_size: int = 128,
) -> torch.Tensor:
"""
De-quantize the k-cache with paged layout
Args:
quant_k_cache: [total_num_tokens, 1, dim_quant] or [num_blocks, block_size, 1, dim_quant], the quantized k-cache in paged layout
page_table_1_flattened: [num_tokens], the flattened page_table_1 with the page indices of each request concatenated together
Returns:
output: [num_tokens, 1, dim_nope + dim_rope], the de-quantized k-cache
"""
dim_quant = quant_k_cache.shape[-1]
assert (
dim_quant == 656
), f"dim_quant: {dim_quant} != 656 detected in dequantize_k_cache_paged"
quant_k_cache = quant_k_cache.view((-1, dim_quant))
total_num_tokens, _ = quant_k_cache.shape
num_tokens = page_table_1_flattened.shape[0]
assert num_tokens <= total_num_tokens
assert quant_k_cache.dtype == torch.float8_e4m3fn
dim_nope = 512
dim_rope = 64
num_tiles = dim_nope // group_size # 512 // 128 = 4
output = torch.empty(
(num_tokens, 1, dim_nope + dim_rope),
dtype=torch.bfloat16,
device=quant_k_cache.device,
)
# cdiv(512 + 64, 128) = 5
num_blocks_per_token = triton.cdiv(dim_nope + dim_rope, group_size)
assert num_blocks_per_token == 5
assert dim_nope % group_size == 0
input_nope_q = quant_k_cache[:, :dim_nope]
# [:, 512:512+4*4] = [:, 512:528]
input_nope_s = quant_k_cache[:, dim_nope : dim_nope + num_tiles * 4].view(
torch.float32
)
# [:, 528:]
input_rope = quant_k_cache[:, dim_nope + num_tiles * 4 :].view(torch.bfloat16)
_dequantize_k_cache_paged_kernel[(num_tokens, num_blocks_per_token)](
output,
input_nope_q,
input_nope_s,
input_rope,
page_table_1_flattened,
output.stride(0),
input_nope_q.stride(0),
input_nope_s.stride(0),
input_rope.stride(0),
NUM_NOPE_BLOCKS=num_tiles,
GROUP_SIZE=group_size,
DIM_NOPE=dim_nope,
DIM_ROPE=dim_rope,
)
return output
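For reference, each 656-byte quantized row packs 512 fp8 NoPE values, 4 float32 group scales (16 bytes, one per 128-wide tile), and 64 bf16 RoPE values (128 bytes). A minimal call sketch with an identity page table; shapes are hypothetical and `_quantize_k_cache_fast_wrapped` is assumed from quant_k_cache.py:

quant_kv = _quantize_k_cache_fast_wrapped(
    torch.randn((8, 64, 1, 512 + 64), dtype=torch.bfloat16, device="cuda")
)                                                                        # [8, 64, 1, 656]
page_table_1 = torch.arange(8 * 64, dtype=torch.int32, device="cuda")    # identity mapping
dequant = dequantize_k_cache_paged(quant_kv, page_table_1)               # [512, 1, 576], bfloat16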
@triton.jit
def _dequantize_k_cache_paged_kernel(
output_ptr,
input_nope_q_ptr,
input_nope_s_ptr,
input_rope_ptr,
page_table_1_ptr,
output_stride_0: int,
input_nope_q_stride_0: int,
input_nope_s_stride_0: int,
input_rope_stride_0: int,
NUM_NOPE_BLOCKS: tl.constexpr,
GROUP_SIZE: tl.constexpr,
DIM_NOPE: tl.constexpr,
DIM_ROPE: tl.constexpr,
):
token_id = tl.program_id(0)
token_id_paged = tl.load(page_table_1_ptr + token_id).to(tl.int32)
raw_block_id = tl.program_id(1)
if raw_block_id < NUM_NOPE_BLOCKS:
# a. dequant nope
effective_block_id = raw_block_id
offs_q = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs_q < DIM_NOPE
ptr_q = input_nope_q_ptr + token_id_paged * input_nope_q_stride_0 + offs_q
ptr_s = (
input_nope_s_ptr
+ token_id_paged * input_nope_s_stride_0
+ effective_block_id
)
y_q = tl.load(ptr_q, mask=mask, other=0.0).to(tl.float32)
y_s = tl.load(ptr_s)
y = (y_q * y_s).to(output_ptr.dtype.element_ty)
dst_ptr = output_ptr + token_id * output_stride_0 + offs_q
tl.store(dst_ptr, y, mask=mask)
else:
# b. copy rope
effective_block_id = raw_block_id - NUM_NOPE_BLOCKS
offs = effective_block_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
mask = offs < DIM_ROPE
src_ptr = input_rope_ptr + token_id_paged * input_rope_stride_0 + offs
dst_ptr = output_ptr + token_id * output_stride_0 + DIM_NOPE + offs
data = tl.load(src_ptr, mask=mask).to(tl.bfloat16)
tl.store(dst_ptr, data, mask=mask)
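As a readability aid, the math the kernel performs can be written as a plain PyTorch reference. This is a minimal, unoptimized sketch under the same 656-byte layout and group_size = 128; it is not part of the commit:

def _dequantize_k_cache_paged_reference(quant_k_cache, page_table_1_flattened):
    # Gather the referenced rows, dequantize the nope groups, and copy the rope part.
    rows = quant_k_cache.view(-1, 656)[page_table_1_flattened.long()]
    nope_q = rows[:, :512].to(torch.float32).view(-1, 4, 128)      # fp8 -> fp32, 4 groups of 128
    scales = rows[:, 512:528].view(torch.float32).unsqueeze(-1)    # [num_tokens, 4, 1] fp32 scales
    rope = rows[:, 528:].view(torch.bfloat16)                      # [num_tokens, 64], already bf16
    nope = (nope_q * scales).to(torch.bfloat16).view(-1, 512)
    return torch.cat([nope, rope], dim=-1).unsqueeze(1)            # [num_tokens, 1, 576]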
if __name__ == "__main__":
raise Exception("UT is in quant_k_cache.py")
......@@ -206,6 +206,8 @@ def _quantize_k_cache_fast_kernel(
if __name__ == "__main__":
import dequant_k_cache
for num_blocks, block_size in [
(1, 1),
(10, 64),
......@@ -217,21 +219,9 @@ if __name__ == "__main__":
dtype=torch.bfloat16,
device="cuda",
)
# temp debug
# input_k_cache = (576 - torch.arange(num_blocks * block_size * 1 * dim_nope_and_rope, device="cuda")).to(torch.bfloat16).reshape(num_blocks, block_size, 1, dim_nope_and_rope)
ref_quant = _quantize_k_cache_slow(input_k_cache)
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
# print(f"{input_k_cache=}")
# print(f"{ref_quant=}")
# print(f"{actual_quant=}")
# print(f"{ref_quant == actual_quant=}")
# print(f"{actual_quant.to(torch.float32) - ref_quant.to(torch.float32)=}")
# print(f"{ref_quant.view(torch.bfloat16)=}")
# print(f"{actual_quant.view(torch.bfloat16)=}")
# assert torch.all(ref_quant == actual_quant)
import dequant_k_cache
ref_ref_dequant = dequant_k_cache._dequantize_k_cache_slow(ref_quant)
ref_actual_dequant = dequant_k_cache._dequantize_k_cache_fast_wrapped(ref_quant)
......@@ -252,4 +242,46 @@ if __name__ == "__main__":
ref_ref_dequant, actual_actual_dequant, atol=0.2, rtol=0.2
)
# test dequant_k_cache_paged
page_table_1 = torch.arange(
num_blocks * block_size, dtype=torch.int32, device="cuda"
)
actual_dequant_paged = dequant_k_cache.dequantize_k_cache_paged(
actual_quant, page_table_1
).reshape(actual_actual_dequant.shape)
print(f"{torch.mean(actual_actual_dequant - actual_dequant_paged)=}")
torch.testing.assert_close(
ref_ref_dequant, actual_dequant_paged, atol=0.2, rtol=0.2
)
print("Passed")
print("Do benchmark...")
for num_blocks, block_size in [
(1, 64),
(64, 64),
(128, 64),
(256, 64),
(512, 64),
(1024, 64),
(2048, 64),
]:
dim_nope_and_rope = 512 + 64
input_k_cache = torch.randn(
(num_blocks, block_size, 1, dim_nope_and_rope),
dtype=torch.bfloat16,
device="cuda",
)
actual_quant = _quantize_k_cache_fast_wrapped(input_k_cache)
page_table_1 = torch.arange(
num_blocks * block_size, dtype=torch.int32, device="cuda"
)
def run_ans():
return dequant_k_cache.dequantize_k_cache_paged(actual_quant, page_table_1)
ans_time: float = triton.testing.do_bench(run_ans, warmup=10, rep=20) / 1000 # type: ignore
print(f"seq_kv: {num_blocks * block_size}, time: {ans_time * 1e6: 4.0f} us")
from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum, auto
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
import torch
from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.dequant_k_cache import dequantize_k_cache_paged
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.layers.attention.nsa.transform_index import (
......@@ -98,11 +100,27 @@ class NSAMetadata:
nsa_max_seqlen_q: Literal[1] = 1 # always 1 for decode, variable for extend
flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
# The sum of sequence lengths for key, prefill only
seq_lens_sum: Optional[int] = None
# The flattened 1D page table with shape (seq_lens_sum,), prefill only.
# This table always uses page_size = 1.
page_table_1_flattened: Optional[torch.Tensor] = None
# The offset of topk indices in ragged kv, prefill only
# shape: (seq_lens_sum,)
topk_indices_offset: Optional[torch.Tensor] = None
class TopkTransformMethod(IntEnum):
# Transform topk indices to indices to the page table (page_size = 1)
PAGED = auto()
# Transform topk indices to indices to ragged kv (non-paged)
RAGGED = auto()
@dataclass(frozen=True)
class NSAIndexerMetadata(BaseIndexerMetadata):
attn_metadata: NSAMetadata
topk_transform_method: TopkTransformMethod
def get_seqlens_int32(self) -> torch.Tensor:
return self.attn_metadata.cache_seqlens_int32
......@@ -118,23 +136,36 @@ class NSAIndexerMetadata(BaseIndexerMetadata):
logits: torch.Tensor,
topk: int,
) -> torch.Tensor:
from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
from sgl_kernel import (
fast_topk_transform_fused,
fast_topk_transform_ragged_fused,
fast_topk_v2,
)
if not NSA_FUSE_TOPK:
return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
# NOTE(dark): if fused, we return a transformed page table directly
return fast_topk_transform_fused(
score=logits,
lengths=self.get_seqlens_expanded(),
page_table_size_1=self.attn_metadata.page_table_1,
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
topk=topk,
)
elif self.topk_transform_method == TopkTransformMethod.PAGED:
# NOTE(dark): if fused, we return a transformed page table directly
return fast_topk_transform_fused(
score=logits,
lengths=self.get_seqlens_expanded(),
page_table_size_1=self.attn_metadata.page_table_1,
cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
topk=topk,
)
elif self.topk_transform_method == TopkTransformMethod.RAGGED:
return fast_topk_transform_ragged_fused(
score=logits,
lengths=self.get_seqlens_expanded(),
topk_indices_offset=self.attn_metadata.topk_indices_offset,
topk=topk,
)
else:
assert False, f"Unsupported {self.topk_transform_method = }"
def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
assert seqlens.dtype == torch.int32 and seqlens.is_cuda
assert seqlens.dtype == torch.int32
return torch.nn.functional.pad(
torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
)
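A small usage example of `compute_cu_seqlens`, plus how the RAGGED prefill path below expands the per-request starts into one offset per query token (values are hypothetical):

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
cu_seqlens = compute_cu_seqlens(seqlens)               # tensor([0, 3, 8, 10], dtype=torch.int32)
extend_seq_lens = torch.tensor([2, 1, 2], device="cuda")
offsets = torch.repeat_interleave(cu_seqlens[:-1], extend_seq_lens)
# offsets == tensor([0, 0, 3, 8, 8]) -- one KV-segment start per query token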
......@@ -181,6 +212,7 @@ class NativeSparseAttnBackend(AttentionBackend):
global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
self.enable_auto_select_prefill_impl = NSA_PREFILL_IMPL == "flashmla_auto"
self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
......@@ -231,10 +263,16 @@ class NativeSparseAttnBackend(AttentionBackend):
cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
assert forward_batch.seq_lens_cpu is not None
max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item() + draft_token_num)
# [b, max_seqlen_k]
page_table = forward_batch.req_to_token_pool.req_to_token[
forward_batch.req_pool_indices, :max_seqlen_k
]
page_table_1_flattened = None
topk_indices_offset = None
self.set_nsa_prefill_impl(forward_batch)
topk_transform_method = self.get_topk_transform_method()
if forward_batch.forward_mode.is_decode_or_idle():
extend_seq_lens_cpu = [1] * batch_size
max_seqlen_q = 1
......@@ -295,6 +333,7 @@ class NativeSparseAttnBackend(AttentionBackend):
else:
max_seqlen_q = max_seqlen_k
cu_seqlens_q = cu_seqlens_k
seqlens_expanded = torch.cat(
[
torch.arange(
......@@ -310,6 +349,24 @@ class NativeSparseAttnBackend(AttentionBackend):
)
]
)
if topk_transform_method == TopkTransformMethod.RAGGED:
page_table_1_flattened = torch.cat(
[
page_table[i, :kv_len]
for i, kv_len in enumerate(
forward_batch.seq_lens_cpu.tolist(),
)
]
)
assert (
page_table_1_flattened.shape[0] == forward_batch.seq_lens_sum
), f"{page_table_1_flattened.shape[0] = } must be the same as {forward_batch.seq_lens_sum = }"
topk_indices_offset = torch.repeat_interleave(
cu_seqlens_k[:-1],
forward_batch.extend_seq_lens,
)
else:
assert False, f"Unsupported {forward_batch.forward_mode = }"
......@@ -328,7 +385,9 @@ class NativeSparseAttnBackend(AttentionBackend):
max_seq_len_k=max_seqlen_k,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
seq_lens_sum=forward_batch.seq_lens_sum,
page_table_1=page_table,
page_table_1_flattened=page_table_1_flattened,
flashmla_metadata=(
self._compute_flashmla_metadata(
cache_seqlens=nsa_cache_seqlens_int32,
......@@ -344,6 +403,7 @@ class NativeSparseAttnBackend(AttentionBackend):
nsa_extend_seq_lens_list=extend_seq_lens_cpu,
real_page_table=self._transform_table_1_to_real(page_table),
nsa_max_seqlen_q=1,
topk_indices_offset=topk_indices_offset,
)
self.forward_metadata = metadata
......@@ -396,6 +456,8 @@ class NativeSparseAttnBackend(AttentionBackend):
forward_mode: ForwardMode,
spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
self.set_nsa_prefill_impl(forward_batch=None)
if forward_mode.is_decode_or_idle():
# Normal Decode
......@@ -586,6 +648,8 @@ class NativeSparseAttnBackend(AttentionBackend):
"""Initialize forward metadata for replaying CUDA graph."""
assert seq_lens_cpu is not None
self.set_nsa_prefill_impl(forward_batch=None)
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
req_pool_indices = req_pool_indices[:bs]
......@@ -780,17 +844,31 @@ class NativeSparseAttnBackend(AttentionBackend):
q_rope = q_all[:, :, layer.v_head_dim :]
# NOTE(dark): here, we use page size = 1
topk_transform_method = self.get_topk_transform_method()
if NSA_FUSE_TOPK:
page_table_1 = topk_indices
else:
assert metadata.nsa_extend_seq_lens_list is not None
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
if topk_transform_method == TopkTransformMethod.RAGGED:
topk_indices_offset = metadata.topk_indices_offset
assert topk_indices_offset is not None
mask = topk_indices != -1
topk_indices_offset = (
topk_indices_offset.unsqueeze(1)
if topk_indices_offset.ndim == 1
else topk_indices_offset
)
topk_indices = torch.where(
mask, topk_indices + topk_indices_offset, topk_indices
)
elif topk_transform_method == TopkTransformMethod.PAGED:
assert metadata.nsa_extend_seq_lens_list is not None
page_table_1 = transform_index_page_table_prefill(
page_table=metadata.page_table_1,
topk_indices=topk_indices,
extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
page_size=1,
)
if NSA_PREFILL_IMPL == "tilelang":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
......@@ -804,6 +882,22 @@ class NativeSparseAttnBackend(AttentionBackend):
elif NSA_PREFILL_IMPL == "flashmla_sparse":
if q_rope is not None:
q_all = torch.cat([q_nope, q_rope], dim=-1)
# NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 has no effect here,
# because the flashmla_sparse kernel doesn't support fp8 compute
if topk_transform_method == TopkTransformMethod.RAGGED:
if any(forward_batch.extend_prefix_lens_cpu):
page_table_1_flattened = (
self.forward_metadata.page_table_1_flattened
)
assert page_table_1_flattened is not None
kv_cache = dequantize_k_cache_paged(
kv_cache, page_table_1_flattened
)
else:
kv_cache = torch.cat([k, k_rope], dim=-1)
page_table_1 = topk_indices
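# With the RAGGED transform, topk_indices were already shifted by
# topk_indices_offset above, so they index straight into the ragged,
# dequantized KV and serve directly as the page-size-1 table for the kernel.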
return self._forward_flashmla_sparse(
q_all=q_all,
kv_cache=kv_cache,
......@@ -1121,10 +1215,52 @@ class NativeSparseAttnBackend(AttentionBackend):
"""Get the fill value for sequence length in CUDA graph."""
return 1
def set_nsa_prefill_impl(self, forward_batch: Optional[ForwardBatch] = None) -> None:
from sglang.srt.utils import is_blackwell
global NSA_PREFILL_IMPL
if self.enable_auto_select_prefill_impl:
if self.nsa_kv_cache_store_fp8:
if (
# TODO(hlu1): enable MTP
is_blackwell()
and forward_batch is not None
and forward_batch.forward_mode.is_extend()
and forward_batch.spec_algorithm.is_none()
):
total_kv_tokens = forward_batch.seq_lens_sum
total_q_tokens = forward_batch.extend_num_tokens
# Heuristic based on benchmarking flashmla_kv vs flashmla_sparse + dequantize_k_cache_paged
if total_kv_tokens < total_q_tokens * 512:
NSA_PREFILL_IMPL = "flashmla_sparse"
return
NSA_PREFILL_IMPL = "flashmla_kv"
else:
# bf16 kv cache
NSA_PREFILL_IMPL = "flashmla_sparse"
def get_topk_transform_method(self) -> TopkTransformMethod:
"""
NSA_FUSE_TOPK controls whether the topk transform is fused into the topk kernel.
This method selects which topk transform method to use; either method can run fused or unfused.
"""
if (
# RAGGED is effectively disabled for MTP: flashmla_auto only selects flashmla_sparse when spec decoding is off
self.nsa_kv_cache_store_fp8
and NSA_PREFILL_IMPL == "flashmla_sparse"
):
topk_transform_method = TopkTransformMethod.RAGGED
else:
topk_transform_method = TopkTransformMethod.PAGED
return topk_transform_method
def get_indexer_metadata(
self, layer_id: int, forward_batch: ForwardBatch
) -> NSAIndexerMetadata:
return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
return NSAIndexerMetadata(
attn_metadata=self.forward_metadata,
topk_transform_method=self.get_topk_transform_method(),
)
def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
from flash_mla import get_mla_metadata
......
......@@ -135,7 +135,16 @@ GRAMMAR_BACKEND_CHOICES = ["xgrammar", "outlines", "llguidance", "none"]
DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
DEFAULT_LORA_EVICTION_POLICY = "lru"
NSA_CHOICES = [
"flashmla_sparse",
"flashmla_kv",
"flashmla_auto",
"fa3",
"tilelang",
"aiter",
]
RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
......@@ -1022,16 +1031,30 @@ class ServerArgs:
import torch
major, _ = torch.cuda.get_device_capability()
if major >= 10:
self.kv_cache_dtype = "fp8_e4m3"
logger.warning("Setting KV cache dtype to fp8.")
if self.kv_cache_dtype == "auto":
self.kv_cache_dtype = "fp8_e4m3" if major >= 10 else "bfloat16"
logger.warning(
f"Setting KV cache dtype to {self.kv_cache_dtype} for DeepSeek NSA."
)
if self.kv_cache_dtype == "bf16":
self.kv_cache_dtype = "bfloat16"
assert self.kv_cache_dtype in [
"bfloat16",
"fp8_e4m3",
], "DeepSeek NSA only supports bf16/bfloat16 or fp8_e4m3 kv_cache_dtype"
if self.kv_cache_dtype == "fp8_e4m3":
self.nsa_prefill_backend = "flashmla_kv"
# flashmla_auto dispatches to flashmla_sparse/flashmla_kv based on hardware and heuristics
self.nsa_prefill_backend = "flashmla_auto"
self.nsa_decode_backend = "flashmla_kv"
logger.warning(
"Setting NSA backend to flashmla_kv for FP8 KV Cache."
"Setting NSA backend to flashmla_auto for prefill and flashmla_kv for decode for FP8 KV Cache."
)
else:
# set prefill/decode backends for Blackwell. The default settings are for Hopper.
if major >= 10:
self.nsa_prefill_backend = "flashmla_sparse"
self.nsa_decode_backend = "flashmla_sparse"
# Logging env vars for NSA
from sglang.srt.layers.attention.nsa.utils import (
......