Commit a1175a4e authored by maxiao1

Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

parents 0c006b88 31653dd9
......@@ -848,10 +848,12 @@ class BenchmarkMetrics:
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
p95_ttft_ms: float
p99_ttft_ms: float
mean_tpot_ms: float
median_tpot_ms: float
std_tpot_ms: float
p95_tpot_ms: float
p99_tpot_ms: float
mean_itl_ms: float
median_itl_ms: float
......@@ -1721,10 +1723,12 @@ def calculate_metrics(
* 1000, # ttfts is empty if streaming is not supported by backend
median_ttft_ms=np.median(ttfts or 0) * 1000,
std_ttft_ms=np.std(ttfts or 0) * 1000,
p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
mean_tpot_ms=np.mean(tpots or 0) * 1000,
median_tpot_ms=np.median(tpots or 0) * 1000,
std_tpot_ms=np.std(tpots or 0) * 1000,
p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
mean_itl_ms=np.mean(itls or 0) * 1000,
median_itl_ms=np.median(itls or 0) * 1000,
......@@ -2052,6 +2056,12 @@ async def benchmark(
print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
......
......@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple
import torch
from sglang.srt.utils import is_hip, is_hpu, is_npu
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu
try:
from lmslim import quant_ops
from lmslim import quant_tools
except Exception:
print("INFO: Please install lmslim if you want to infer gptq or awq or w8a8 model.\n")
try:
import lightop
except Exception:
print("INFO: Please install lightop if you want to infer awq of marlin.\n")
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
use_dcu_custom_allreduce = get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="true"
)
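# Illustrative sketch (the helper name is hypothetical, not part of this diff): how
# these two flags are expected to select an allreduce implementation, mirroring the
# checks used further below and in custom_all_reduce.py.
def _select_allreduce_backend(on_hip: bool) -> str:
    if use_vllm_custom_allreduce and not on_hip:
        return "vllm_custom_allreduce"    # vLLM custom allreduce kernels
    if use_dcu_custom_allreduce:
        return "dcu_custom_allreduce"     # torch.ops._C_custom_ar via vllm._C
    return "sgl_kernel_allreduce"         # sgl_kernel.allreduce (ROCm / TRT-LLM)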
if not is_hpu():
try:
......@@ -15,6 +29,11 @@ if not is_hpu():
except ImportError as e:
logger.warning("Failed to import from custom_ar with %r", e)
if use_dcu_custom_allreduce:
try:
import vllm._C
except ImportError as e:
logger.warning("Failed to import from vllm._C with %r", e)
if not is_hip() and not is_npu():
custom_op = sgl_kernel.allreduce
......@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
) -> None:
custom_op.register_graph_buffers(fa, handles, offsets)
elif is_hip() and use_dcu_custom_allreduce:
# custom ar
def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
rank: int, fully_connected: bool) -> int:
return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
fully_connected)
def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
reg_buffer_sz_bytes: int) -> None:
torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
reg_buffer_sz_bytes)
def dispose(fa: int) -> None:
torch.ops._C_custom_ar.dispose(fa)
def meta_size() -> int:
return torch.ops._C_custom_ar.meta_size()
def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
def register_graph_buffers(fa: int, handles: list[list[int]],
offsets: list[list[int]]) -> None:
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)
def open_mem_handle(mem_handle: torch.Tensor):
return torch.ops._C_custom_ar.open_mem_handle(mem_handle)
def free_shared_buffer(ptr: int) -> None:
torch.ops._C_custom_ar.free_shared_buffer(ptr)
def read_cache(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
def write_cache_multi_layers(
keys: torch.Tensor,
values: torch.Tensor,
key_caches: list[torch.Tensor],
value_caches: list[torch.Tensor],
slot_mapping: torch.Tensor,
kv_cache_dtype: str
) -> None:
torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
value_caches, slot_mapping,
kv_cache_dtype)
else:
# ROCM custom allreduce
# sgl_kernel ROCM custom allreduce
def init_custom_ar(
meta: torch.Tensor,
......@@ -163,3 +253,83 @@ def mscclpp_allreduce(
context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
def triton_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None,
best_config: Optional[list] = None) -> torch.Tensor:
return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)
def cutlass_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
`cutlass_scaled_mm` implements a fused version of
`output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
where scale_a * a and scale_b * b are implemented using numpy-style
broadcasting.
In order to support blockwise scaling like found in DeepSeek V3 we also
support extended "group" broadcast rules. We extend the numpy-style
broadcasting rules with the following rule:
"if the extent of a dimension in the source shape is between 1 and
corresponding extent in the target shape we repeat each element along
that dimension src_shape[dim] // target_shape[dim] times consecutively"
example if we have:
a = [[1, 2], and target_shape = (2, 4)
[3, 4]]
then we would expand a to:
a = [[1, 1, 2, 2],
[3, 3, 4, 4]]
currently we only support the case:
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
assert bias is None or (bias.shape[0] == b.shape[1] and bias.dtype == out_dtype)
# m = a.shape[0]
# n = b.shape[1]
# cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
# if current_platform.is_rocm() or not cutlass_compatible_b:
# from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import ( # noqa
# triton_scaled_mm)
# return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
# out = torch.empty((m, n), dtype=out_dtype, device=a.device)
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
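# Illustrative sketch (expand_block_scale is hypothetical, not part of this module):
# the "group" broadcast described in the docstring above can be reproduced with
# torch.repeat_interleave; the fused scaled-mm kernels do this expansion implicitly.
def expand_block_scale(scale: torch.Tensor, target_shape) -> torch.Tensor:
    # Repeat each scale entry target_shape[dim] // scale.shape[dim] times per dim,
    # e.g. a (M, K // 128) activation scale expands to (M, K).
    reps0 = target_shape[0] // scale.shape[0]
    reps1 = target_shape[1] // scale.shape[1]
    return scale.repeat_interleave(reps0, dim=0).repeat_interleave(reps1, dim=1)

# Reference semantics (unfused, for understanding only):
# ref = torch.mm(expand_block_scale(scale_a, a.shape) * a.float(),
#                expand_block_scale(scale_b, b.shape) * b.float()).to(out_dtype)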
def rocblas_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def triton_int8_gemm_helper(m: int,
n: int,
k: int,
per_token_act_quant: bool,
per_out_channel_weight_quant: bool,
use_bias: bool,
out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda:0",
best_config: Optional[list] = None,
repeat: Optional[int] = 2):
return quant_tools.triton_int8_gemm_helper(m, n, k, per_token_act_quant, per_out_channel_weight_quant, use_bias, out_dtype, device, best_config, repeat)
\ No newline at end of file
......@@ -635,7 +635,9 @@ class ModelConfig:
"petit_nvfp4",
"quark",
"mxfp4",
"auto-round",
"slimquant_w4a8_marlin",
"w8a8_int8",
"slimquant_marlin",
]
optimized_quantization_methods = [
"fp8",
......@@ -655,6 +657,8 @@ class ModelConfig:
"qoq",
"w4afp8",
"petit_nvfp4",
"slimquant_w4a8_marlin",
"slimquant_marlin",
]
compatible_quantization_methods = {
"modelopt_fp8": ["modelopt"],
......
......@@ -34,6 +34,21 @@ except ImportError:
_is_cuda = is_cuda()
_is_hip = is_hip()
try:
if ops.use_vllm_custom_allreduce and not _is_hip:
# Use vLLM custom allreduce
ops.meta_size()
elif ops.use_dcu_custom_allreduce:
ops.meta_size()
else:
# Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
import sgl_kernel # noqa: F401
custom_ar = True
except Exception:
# For CPUs
custom_ar = False
logger = logging.getLogger(__name__)
......@@ -416,3 +431,274 @@ class CustomAllreduce:
def __del__(self):
self.close()
class DCUCustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 512) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the DCUCustomAllreduce to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
if world_size > 16:
return
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in DCUCustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(DCUCustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# assert current_platform.is_cuda_alike()
# fully_connected = current_platform.is_fully_connected(
# physical_device_ids)
if _is_cuda or _is_hip:
fully_connected = is_full_nvlink(physical_device_ids, world_size)
# if world_size > 2 and not fully_connected:
if not fully_connected:
max_size = 32 * 8192 * 2
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
# logger.warning(
# "Custom allreduce is disabled because it's not supported on"
# " more than two PCIe-only GPUs. To silence this warning, "
# "specify disable_custom_all_reduce=True explicitly.")
# return
logger.warning(
"Using custom allreduce over PCIe-only GPUs. If performance "
"is poor, pass --disable-custom-all-reduce to disable it.")
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not _is_hip and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# Buffer memory is owned by this Python class and passed to C++.
# Metadata is composed of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
return inp_size <= self.max_size
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
ops.free_shared_buffer(pointers[rank])
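# Illustrative usage sketch (the helper below is hypothetical): callers try the custom
# IPC-based path first and fall back to the regular process-group allreduce whenever
# custom_all_reduce declines the input (disabled comm, size not a multiple of 16, etc.).
def _allreduce_with_fallback(comm: DCUCustomAllreduce, x: torch.Tensor) -> torch.Tensor:
    y = comm.custom_all_reduce(x)   # returns None when the custom path is not applicable
    if y is None:
        # Fall back to the default device process group (assumption; the CPU gloo
        # group attached to the communicator is only used for metadata exchange).
        dist.all_reduce(x)
        return x
    return y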
......@@ -54,6 +54,7 @@ from sglang.srt.utils import (
is_xpu,
supports_custom_op,
)
from sglang.srt import _custom_ops as ops
_is_npu = is_npu()
_is_cpu = is_cpu()
......@@ -327,7 +328,7 @@ class GroupCoordinator:
# Lazy import to avoid documentation build error
from sglang.srt.distributed.device_communicators.custom_all_reduce import (
CustomAllreduce,
CustomAllreduce, DCUCustomAllreduce
)
from sglang.srt.distributed.device_communicators.pymscclpp import (
PyMscclppCommunicator,
......@@ -371,10 +372,17 @@ class GroupCoordinator:
if use_custom_allreduce and self.world_size > 1:
# Initialize a custom fast all-reduce implementation.
try:
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
)
if is_hip() and ops.use_dcu_custom_allreduce:
self.ca_comm = DCUCustomAllreduce(
group=self.cpu_group,
device=self.device,
)
else:
self.ca_comm = CustomAllreduce(
group=self.cpu_group,
device=self.device,
max_size=ca_max_size,
)
except Exception as e:
logger.warning(
f"Setup Custom allreduce failed with {e}. To silence this "
......
......@@ -188,6 +188,17 @@ class Envs:
SGLANG_USE_AITER = EnvBool(False)
SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
# DCU Lightop
SGLANG_USE_LIGHTOP = EnvBool(False)
# Fused
SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
SGLANG_USE_OPT_CAT = EnvBool(False)
SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)
# Quantization
SGLANG_INT4_WEIGHT = EnvBool(False)
......
......@@ -99,7 +99,6 @@ def create_triton_backend(runner):
return TritonAttnBackend(runner)
@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
......@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
return FlashMLABackend(runner)
@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend
return DCUMLABackend(runner)
@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
......
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union
import torch
import triton
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
from sglang.srt.utils import get_bool_env_var
try:
from flash_mla import (
flash_mla_with_kvcache,
flash_mla_with_kvcache_quantization,
get_mla_metadata
)
_has_flash_mla = True
except Exception:
try:
from vllm.attention.ops.flashmla import (
flash_mla_with_kvcache,
get_mla_metadata
)
_has_flash_mla = False
except Exception:
raise ImportError(
"Can not import FlashMLA。Please perform the following operations to use flashmla:\n"
" pip install flash-mla\n"
" or\n"
" pip install vllm"
)
PAGE_SIZE = 64  # fixed to 64
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInput
@dataclass
class VllmMLADecodeMetadata:
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
num_splits: Optional[torch.Tensor] = None
block_kv_indices: Optional[torch.Tensor] = None
def __init__(
self,
flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
num_splits: Optional[torch.Tensor] = None,
block_kv_indices: Optional[torch.Tensor] = None,
):
self.flashmla_metadata = flashmla_metadata
self.num_splits = num_splits
self.block_kv_indices = block_kv_indices
class DCUMLABackend(AttentionBackend):
def __init__(
self,
model_runner: "ModelRunner",
skip_prefill: bool = False,
kv_indptr_buf: Optional[torch.Tensor] = None,
kv_last_page_len_buf: Optional[torch.Tensor] = None,
):
super().__init__()
if model_runner.server_args.page_size != PAGE_SIZE:
raise ValueError(
f"dcu_mla backend requires page_size={PAGE_SIZE}, "
f"but got the {model_runner.server_args.page_size}"
)
self.num_q_heads = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.kv_lora_rank = model_runner.model_config.kv_lora_rank
self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
self.v_head_dim = model_runner.model_config.v_head_dim
self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
self.data_type = model_runner.kv_cache_dtype
self.q_data_type = model_runner.dtype
self.device = model_runner.device
self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self.max_context_len = model_runner.model_config.context_len
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.forward_metadata: Union[VllmMLADecodeMetadata] = None
self.skip_prefill = skip_prefill
if not skip_prefill:
from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
self.flashattn_backend = FlashAttentionBackend(
model_runner,
skip_prefill=False,
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(
forward_batch.seq_lens_cpu.max().item(), PAGE_SIZE
)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=forward_batch.seq_lens.device
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr=self.req_to_token,
req_pool_indices_ptr=forward_batch.req_pool_indices,
page_kernel_lens_ptr=forward_batch.seq_lens,
kv_start_idx=None,
kv_indices_ptr=block_kv_indices,
req_to_token_ptr_stride=self.req_to_token.stride(0),
kv_indices_ptr_stride=max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
1
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices
)
elif forward_batch.forward_mode.is_target_verify():
seq_lens_cpu = forward_batch.seq_lens_cpu + self.num_draft_tokens
seq_lens = forward_batch.seq_lens + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=seq_lens.device,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr=self.req_to_token,
req_pool_indices_ptr=forward_batch.req_pool_indices,
page_kernel_lens_ptr=seq_lens,
kv_start_idx=None,
kv_indices_ptr=block_kv_indices,
req_to_token_ptr_stride=self.req_to_token.stride(0),
kv_indices_ptr_stride=max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
1,
)
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices
)
else:
if not self.skip_prefill:
# === DRAFT_EXTEND_V2 MLA metadata ===
if forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2:
bs = forward_batch.batch_size
seq_lens_cpu = forward_batch.seq_lens_cpu
seq_lens = forward_batch.seq_lens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
block_kv_indices = torch.full(
(bs, max_seqlen_pad),
-1,
dtype=torch.int32,
device=seq_lens.device,
)
# Call the Triton kernel to generate block_kv_indices
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr=self.req_to_token.to(torch.int32),
req_pool_indices_ptr=forward_batch.req_pool_indices.to(torch.int32),
page_kernel_lens_ptr=forward_batch.seq_lens.to(torch.int32),
kv_start_idx=None,
kv_indices_ptr=block_kv_indices.to(torch.int32),
req_to_token_ptr_stride=self.req_to_token.stride(0),
kv_indices_ptr_stride=max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
# MLA
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_q_heads,
1,
)
# save forward_metadata
self.forward_metadata = VllmMLADecodeMetadata(
mla_metadata,
num_splits,
block_kv_indices,
)
self.flashattn_backend.init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
cuda_graph_kv_indices = torch.full(
(max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
1,
dtype=torch.int32,
device="cuda",
)
else:
cuda_graph_kv_indices = block_kv_indices
if self.num_draft_tokens:
mla_metadata, num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
self.num_draft_tokens * self.num_q_heads,
1,
)
else:
mla_metadata, num_splits = get_mla_metadata(
torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
self.num_q_heads,
1,
)
self.cuda_graph_mla_metadata = mla_metadata
self.cuda_graph_num_splits = num_splits
self.cuda_graph_kv_indices = cuda_graph_kv_indices
def init_forward_metadata_capture_cuda_graph(
self,
bs: int,
num_tokens: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional["SpecInput"],
):
if forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = VllmMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
elif forward_mode.is_target_verify():
seq_lens = seq_lens + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata = VllmMLADecodeMetadata(
self.cuda_graph_mla_metadata,
self.cuda_graph_num_splits[: bs + 1],
self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
)
else:
if not self.skip_prefill:
self.flashattn_backend.init_forward_metadata_capture_cuda_graph(
bs,
num_tokens,
req_pool_indices,
seq_lens,
encoder_lens,
forward_mode,
spec_info,
)
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional["SpecInput"],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
assert seq_lens_cpu is not None
seq_lens = seq_lens[:bs]
seq_lens_cpu = seq_lens_cpu[:bs]
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
num_q_heads = self.num_q_heads * (self.num_draft_tokens or 1)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
elif forward_mode.is_target_verify():
seq_lens = seq_lens[:bs] + self.num_draft_tokens
seq_lens_cpu = seq_lens_cpu[:bs] + self.num_draft_tokens
max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices[:bs],
seq_lens,
None,
self.cuda_graph_kv_indices,
self.req_to_token.stride(0),
self.cuda_graph_kv_indices.stride(0),
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32), self.num_draft_tokens * self.num_q_heads, 1
)
self.cuda_graph_mla_metadata.copy_(mla_metadata)
self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
self.forward_metadata.flashmla_metadata = self.cuda_graph_mla_metadata
self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
:bs, :max_seqlen_pad
]
else:
if not self.skip_prefill:
self.flashattn_backend.init_forward_metadata_replay_cuda_graph(
bs,
req_pool_indices,
seq_lens,
seq_lens_sum,
encoder_lens,
forward_mode,
spec_info,
seq_lens_cpu,
)
def get_cuda_graph_seq_len_fill_value(self):
return 1
def _call_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float):
o, _ = flash_mla_with_kvcache(
q=reshape_q,
k_cache=k_cache_reshaped,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling,
causal=True,
)
return o
def _call_fp8_decode(self, reshape_q: torch.Tensor, k_cache_reshaped: torch.Tensor,
block_table: torch.Tensor, cache_seqlens: torch.Tensor,
scaling: float, k_scale=None, kv_cache_dtype=None):
assert _has_flash_mla, "FP8 KV cache requires the flash_mla package"
o, _ = flash_mla_with_kvcache_quantization(
q=reshape_q,
k_cache=k_cache_reshaped,
block_table=block_table,
cache_seqlens=cache_seqlens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
num_splits=self.forward_metadata.num_splits,
softmax_scale=scaling,
causal=True,
k_scale=k_scale,
kv_cache_dtype=kv_cache_dtype,
)
return o
@torch._dynamo.disable()  # NOTE: FP8 KV cache decode does not support torch.compile
def forward_decode(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
):
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
kv_cache_dtype="fp8_e4m3"
else:
kv_cache_dtype="fp8_e5m2"
k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
o = self._call_fp8_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
)
else:
o = self._call_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@torch._dynamo.disable()
def forward_extend(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
layer: "RadixAttention",
forward_batch: ForwardBatch,
save_kv_cache: bool = True,
# For multi-head latent attention
q_rope: Optional[torch.Tensor] = None,
k_rope: Optional[torch.Tensor] = None,
sinks=None,
):
if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND):
if not self.skip_prefill:
return self.flashattn_backend.forward_extend(
q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope, sinks
)
else:
raise RuntimeError("skip prefill but use forward_extend")
cache_loc = forward_batch.out_cache_loc
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer,
cache_loc,
k,
v,
)
bs = forward_batch.batch_size
k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
k_cache_reshaped = k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim)
num_draft_tokens = self.num_draft_tokens if self.num_draft_tokens is not None else 0
if self.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz,
torch.float8_e5m2, torch.float8_e5m2fnuz):
if self.data_type in (torch.float8_e4m3fnuz, torch.float8_e4m3fn):
kv_cache_dtype="fp8_e4m3"
else:
kv_cache_dtype="fp8_e5m2"
k_scale = layer.k_scale if layer.k_scale is not None else self.k_scale
o = self._call_fp8_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
k_scale,
kv_cache_dtype=kv_cache_dtype,
)
else:
o = self._call_decode(
reshape_q,
k_cache_reshaped,
self.forward_metadata.block_kv_indices[:bs],
(forward_batch.seq_lens + num_draft_tokens).to(torch.int32),
layer.scaling,
)
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
class DCUMLAMultiStepDraftBackend:
"""
Wrap multiple flashmla attention backends as one for multiple consecutive
draft decoding steps.
"""
def __init__(
self,
model_runner: ModelRunner,
topk: int,
speculative_num_steps: int,
):
if topk > 1:
raise ValueError(
"Currently FlashMLA only supports topk=1 for speculative decoding"
)
self.topk = topk
self.speculative_num_steps = speculative_num_steps
max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(
self.speculative_num_steps,
max_bs + 1,
),
dtype=torch.int32,
device=model_runner.device,
)
self.attn_backends = []
for i in range(self.speculative_num_steps - 1):
self.attn_backends.append(
DCUMLABackend(
model_runner,
skip_prefill=True,
kv_indptr_buf=self.kv_indptr[i],
kv_last_page_len_buf=None,
)
)
def common_template(
self,
forward_batch: ForwardBatch,
call_fn: Callable,
):
assert forward_batch.spec_info is not None
for i in range(self.speculative_num_steps - 1):
call_fn(i, forward_batch)
def init_forward_metadata(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
self.attn_backends[i].init_forward_metadata(forward_batch)
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, block_kv_indices=None
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
forward_batch.batch_size,
forward_batch.batch_size * self.topk,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
)
self.common_template(forward_batch, call_fn)
def init_forward_metadata_replay_cuda_graph(
self, forward_batch: ForwardBatch, bs: int
):
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,
encoder_lens=None,
forward_mode=ForwardMode.DECODE,
spec_info=forward_batch.spec_info,
seq_lens_cpu=forward_batch.seq_lens_cpu,
)
self.common_template(forward_batch, call_fn)
......@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
convert_vertical_slash_indexes,
convert_vertical_slash_indexes_mergehead,
......
......@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
@dataclass
......@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
self.skip_prefill = skip_prefill
self.is_hybrid = model_runner.is_hybrid
self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
if self.is_hybrid:
self.full_to_swa_index_mapping = (
model_runner.token_to_kv_pool.full_to_swa_index_mapping
......@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
forward_batch.req_pool_indices, : metadata.max_seq_len_k
]
if any(
forward_batch.extend_prefix_lens_cpu
) or forward_batch.forward_mode.is_draft_extend(include_v2=True):
if (
any(forward_batch.extend_prefix_lens_cpu)
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2
):
extend_seq_lens = forward_batch.extend_seq_lens
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
metadata.cu_seqlens_q = torch.nn.functional.pad(
......@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
metadata.max_seq_len_q = metadata.max_seq_len_k
metadata.cu_seqlens_q = metadata.cu_seqlens_k
# Setup local attention if enabled
if forward_batch.forward_mode == ForwardMode.EXTEND:
# # Setup local attention if enabled
# if forward_batch.forward_mode == ForwardMode.EXTEND:
# self._init_local_attn_metadata(forward_batch, metadata, device)
if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
self._init_local_attn_metadata(forward_batch, metadata, device)
# Encoder metadata for cross attention
if forward_batch.encoder_lens is not None:
assert (
......@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
if k_rope is None:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
layer, cache_loc, k, v, #layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
layer.sliding_window_size is not None and layer.sliding_window_size > -1
)
window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
......@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)
v_descale = layer.v_scale.expand(descale_shape)
q = q.to(self.kv_cache_dtype)
# q = q.to(self.kv_cache_dtype)
q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
causal = True
......@@ -774,61 +784,59 @@ class FlashAttentionBackend(AttentionBackend):
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
# MHA for chunked prefix kv cache when running model with MLA
assert forward_batch.prefix_chunk_idx is not None
assert forward_batch.prefix_chunk_cu_seq_lens is not None
assert forward_batch.prefix_chunk_max_seq_lens is not None
chunk_idx = forward_batch.prefix_chunk_idx
assert chunk_idx >= 0
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
o, _ = merge_state_v2_wrapper(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
)
else:
o = result
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=metadata.cu_seqlens_q,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=metadata.max_seq_len_q,
softmax_scale=layer.scaling,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
if forward_batch.mha_return_lse:
output, lse, *rest = output
lse = torch.transpose(lse, 0, 1).contiguous()
return output, lse
return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
else:
if (
forward_batch.attn_attend_prefix_cache is not None
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend(include_v2=True)
):
and not forward_batch.forward_mode.is_draft_extend()
):
# Do multi-head attention with chunked prefix cache
if forward_batch.attn_attend_prefix_cache:
assert not get_global_server_args().disable_chunked_prefix_cache
......@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
assert forward_batch.mha_return_lse
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
softmax_scale=layer.scaling,
causal=False,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
**kwargs,
)
else:
# MHA for extend part of sequence without attending prefix kv cache
cu_seqlens_k = (
metadata.cu_seqlens_q
if not forward_batch.mha_one_shot
else metadata.cu_seqlens_k
)
max_seqlen_k = (
metadata.max_seq_len_q
if not forward_batch.mha_one_shot
else metadata.max_seq_len_k
)
output = flash_attn_varlen_func(
q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).to(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).to(q.dtype),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=metadata.max_seq_len_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=layer.scaling,
causal=True,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=forward_batch.mha_return_lse,
**kwargs,
)
......@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
if not layer.is_cross_attention
else forward_batch.encoder_out_cache_loc
)
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
# if not self.use_mla:
if k_rope is None:
if not self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v
)
else:
forward_batch.token_to_kv_pool.set_mla_kv_buffer(
layer,
......@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
if sinks is not None:
kwargs["sinks"] = sinks
k_descale, v_descale = None, None
# k_descale, v_descale = None, None
k_descale, v_descale = self.k_scale, self.v_scale
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
......@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
if not self.use_mla:
# Do multi-head attention
key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
......@@ -1096,65 +1103,62 @@ class FlashAttentionBackend(AttentionBackend):
**kwargs,
)
else:
cu_seqlens_q = metadata.cu_seqlens_q
max_seqlen_q = metadata.max_seq_len_q
page_table = metadata.page_table
cache_seqlens = metadata.cache_seqlens_int32
cu_seqlens_k = metadata.cu_seqlens_k
max_seqlen_q = metadata.max_seq_len_q
q_reshaped = q.contiguous().view(
-1, layer.tp_q_head_num, layer.head_dim
cache_seqlens = metadata.cache_seqlens_int32
key_cache = key_cache.view(
-1, self.page_size, layer.tp_k_head_num, layer.head_dim
)
# Default: single-token self-attention
result = flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=metadata.cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=False if use_cascade_attn else causal,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
value_cache = value_cache.view(
-1, self.page_size, layer.tp_v_head_num, layer.head_dim
)
if use_cascade_attn:
o, softmax_lse, *rest = result
o_expand, softmax_lse_expand, *rest_expand = (
flash_attn_with_kvcache(
q=q_reshaped,
k_cache=key_cache,
v_cache=value_cache,
page_table=self.forward_metadata_spec_decode_expand.page_table,
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
softmax_scale=layer.scaling,
causal=False,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=True,
num_splits=self.num_splits,
**kwargs,
)
)
o, _ = merge_state_v2(
o,
softmax_lse.T.contiguous(),
o_expand,
softmax_lse_expand.T.contiguous(),
if layer.is_cross_attention:
page_table = metadata.encoder_page_table
cache_seqlens = metadata.encoder_lens_int32
cu_seqlens_k = metadata.encoder_cu_seqlens_k
window_size = (-1, -1)
if max_seqlen_q > 1:
result = flash_attn_varlen_func(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
else:
o = result
result = flash_attn_with_kvcache(
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache=key_cache,
v_cache=value_cache,
page_table=page_table,
cache_seqlens=cache_seqlens,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
max_seqlen_q=max_seqlen_q,
softmax_scale=layer.scaling,
causal=True,
window_size=window_size,
softcap=layer.logit_cap,
k_descale=k_descale,
v_descale=v_descale,
return_softmax_lse=use_cascade_attn,
num_splits=self.num_splits,
**kwargs,
)
o = result
else:
# Do absorbed multi-latent attention
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
......
from flash_attn import (
flash_attn_varlen_func as flash_attn_varlen_func_interface,
flash_attn_with_kvcache as flash_attn_with_kvcache_interface
)
from typing import Optional, Union
import torch
def flash_attn_with_kvcache(
q,
k_cache,
v_cache,
k=None,
v=None,
qv=None,
rotary_cos=None,
rotary_sin=None,
cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
cache_batch_idx: Optional[torch.Tensor] = None,
cache_leftpad: Optional[torch.Tensor] = None,
page_table: Optional[torch.Tensor] = None,
cu_seqlens_q: Optional[torch.Tensor] = None,
cu_seqlens_k_new: Optional[torch.Tensor] = None,
max_seqlen_q: Optional[int] = None,
rotary_seqlens: Optional[torch.Tensor] = None,
q_descale: Optional[torch.Tensor] = None,
k_descale: Optional[torch.Tensor] = None,
v_descale: Optional[torch.Tensor] = None,
softmax_scale=None,
causal=False,
window_size=(-1, -1), # -1 means infinite context window
attention_chunk: Optional[int] = None,
softcap=0.0, # 0.0 means deactivated
rotary_interleaved=True,
scheduler_metadata=None,
num_splits=0, # Can be tuned for speed
pack_gqa=None, # Can be tuned for speed
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
ver=3,
):
return flash_attn_with_kvcache_interface(
q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
k_cache=k_cache.view(q.dtype),
v_cache=v_cache.view(q.dtype),
block_table=page_table,
cache_seqlens=cache_seqlens,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
softcap=softcap,
return_softmax_lse=return_softmax_lse,
num_splits=num_splits,
)
def flash_attn_varlen_func(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q=None,
max_seqlen_k=None,
seqused_q=None,
seqused_k=None,
page_table=None,
softmax_scale=None,
causal=False,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=(-1, -1),
attention_chunk=0,
softcap=0.0,
num_splits=1,
pack_gqa=None,
sm_margin=0,
return_softmax_lse=False,
sinks=None,
ver=3,
):
return flash_attn_varlen_func_interface(
q=q,
k=k.view(q.dtype),
v=v.view(q.dtype),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=softmax_scale,
causal=causal,
return_attn_probs=return_softmax_lse,
softcap=softcap,
)
\ No newline at end of file
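# Hedged usage sketch for the shim above (illustrative shapes; requires CUDA and the
# flash_attn package). Note that arguments such as page_table, k_descale/v_descale and
# sinks are accepted for interface compatibility but are not forwarded to flash_attn.
# q = torch.randn(8, 16, 128, device="cuda", dtype=torch.float16)   # 8 tokens, 16 heads
# k = torch.randn(8, 16, 128, device="cuda", dtype=torch.float16)
# v = torch.randn(8, 16, 128, device="cuda", dtype=torch.float16)
# cu = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")    # two seqs: 3 and 5 tokens
# out = flash_attn_varlen_func(q, k, v, cu_seqlens_q=cu, cu_seqlens_k=cu,
#                              max_seqlen_q=5, max_seqlen_k=5,
#                              softmax_scale=128 ** -0.5, causal=True)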
......@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
def init_forward_metadata(self, forward_batch: ForwardBatch):
use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var("SGLANG_CREATE_FLASHMLA_KV_INDICES_TRITON")
bs = forward_batch.batch_size
if forward_batch.forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(
......@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=forward_batch.seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr=self.req_to_token,
req_pool_indices_ptr=forward_batch.req_pool_indices,
page_kernel_lens_ptr=forward_batch.seq_lens,
kv_start_idx=None,
kv_indices_ptr=block_kv_indices,
req_to_token_ptr_stride=self.req_to_token.stride(0),
kv_indices_ptr_stride=max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
forward_batch.seq_lens.to(torch.int32),
self.num_q_heads,
......@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
dtype=torch.int32,
device=seq_lens.device,
)
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
if use_sglang_create_flashmla_kv_indices_triton:
dcu_create_flashmla_kv_indices(
req_to_token_ptr = self.req_to_token,
req_pool_indices_ptr = forward_batch.req_pool_indices,
                    page_kernel_lens_ptr = seq_lens,
kv_start_idx = None,
kv_indices_ptr = block_kv_indices,
req_to_token_ptr_stride = self.req_to_token.stride(0),
kv_indices_ptr_stride = max_seqlen_pad,
)
else:
create_flashmla_kv_indices_triton[(bs,)](
self.req_to_token,
forward_batch.req_pool_indices,
                    seq_lens,
None,
block_kv_indices,
self.req_to_token.stride(0),
max_seqlen_pad,
)
mla_metadata, num_splits = get_mla_metadata(
seq_lens.to(torch.int32),
self.num_draft_tokens * self.num_q_heads,
......@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
)
else:
super().init_forward_metadata(forward_batch)
def init_cuda_graph_state(
self,
max_bs: int,
......
from __future__ import annotations
import warnings
import torch
from sglang.srt.utils import get_bool_env_var, direct_register_custom_op
_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")
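# SGLANG_USE_OPT_CAT opts into lightop's fused ds_cat concat kernel; when it is off (or the
# import fails) plain torch.cat is used instead.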
if _USE_OPT_CAT:
try:
from lightop import ds_cat # type: ignore
except ImportError: # pragma: no cover
ds_cat = None
warnings.warn(
"SGLANG_USE_OPT_CAT 已开启但无法导入 lightop.ds_cat,退回 torch.cat"
)
else:
ds_cat = None
# TODO: registering this op separately still has some issues
def ds_cat_wrapper(A: torch.Tensor,
                   B: torch.Tensor,
                   dim: int,
                   mode: int) -> torch.Tensor:
    if ds_cat is None:
        # lightop is unavailable; fall back to the standard concat as the warning above promises.
        return torch.cat([A, B], dim=dim)
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    ds_cat(A, B, C, mode)
    return C
def ds_cat_fake(A: torch.Tensor,
B: torch.Tensor,
dim: int,
mode: int) -> torch.Tensor:
    # Use plain torch.cat as the fake (meta) implementation
return torch.cat([A, B], dim=dim)
direct_register_custom_op(
op_name="ds_cat",
op_func=ds_cat_wrapper,
    mutates_args=[],  # no arguments are mutated; the result is returned as a new tensor
fake_impl=ds_cat_fake
)
def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    # Only the 3-D, dim == 2 case used on the decode path is supported.
    assert dim == 2, "input tensors must be 3-D and the concat dim must be 2"
    mode = 0
    return torch.ops.sglang.ds_cat(A, B, dim, mode)
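# Illustrative usage (hypothetical shapes), concatenating two 3-D decode tensors along dim 2:
#   A: [bs, heads, d1], B: [bs, heads, d2]  ->  C: [bs, heads, d1 + d2]
#   C = concat_decode_opt(A, B, dim=2)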
# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
# assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
# output_shape = list(A.shape)
# output_shape[dim] = A.shape[dim] + B.shape[dim]
# C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
# mode=0
# if dim!=0 :
# ds_cat(A, B, C, mode)
# return C
# assert False, "not support"
......@@ -47,7 +47,8 @@ if _is_hip:
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
else:
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache
@dataclass(frozen=True)
......
......@@ -20,7 +20,8 @@ if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sgl_kernel import merge_state_v2
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
class XPUAttentionBackend(AttentionBackend):
......
......@@ -160,21 +160,53 @@ class RMSNorm(CustomOp):
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
# def forward_hip(
# self,
# x: torch.Tensor,
# residual: Optional[torch.Tensor] = None,
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if not x.is_contiguous():
# # NOTE: Remove this if aiter kernel supports discontinuous input
# x = x.contiguous()
# if residual is not None:
# if _vllm_version < Version("0.9"):
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
# return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
# out = torch.empty_like(x)
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
# return out
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
):
if not x.is_contiguous():
# NOTE: Remove this if aiter kernel supports discontinuous input
x = x.contiguous()
if residual is not None:
if _vllm_version < Version("0.9"):
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
try:
fused_add_rms_norm(
x,
residual,
self.weight.data,
self.variance_epsilon,
)
return x, residual
else:
residual_out = torch.empty_like(x)
except TypeError:
output = torch.empty_like(x)
residual_out = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
......@@ -184,10 +216,13 @@ class RMSNorm(CustomOp):
self.variance_epsilon,
)
return output, residual_out
out = torch.empty_like(x)
rms_norm(out, x, self.weight.data, self.variance_epsilon)
return out
def forward_native(
self,
x: torch.Tensor,
......
......@@ -45,6 +45,18 @@ _is_hip = is_hip()
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
"SGLANG_ROCM_DISABLE_LINEARQUANT"
)
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
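# These flags opt into lmslim's fused RMSNorm+quant and SiLU*mul+quant kernels; the imports
# below are best-effort and only attempted when the corresponding flag is set.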
if _use_fused_rms_quant:
try:
from lmslim.quantize.quant_ops import lm_faster_rmsquant
except Exception as e:
print(f"Error: Import fused rmsquant error: {e}")
if _use_fused_silu_mul_quant:
try:
from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
except Exception as e:
print(f"Error: Import fused silu_mul_quant error: {e}")
logger = logging.getLogger(__name__)
......@@ -1360,7 +1372,7 @@ class RowParallelLinear(LinearBase):
# It does not support additional parameters.
param.load_row_parallel_weight(loaded_weight)
def forward(self, input_, skip_all_reduce=False):
def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
if self.input_is_parallel:
input_parallel = input_
else:
......@@ -1374,10 +1386,19 @@ class RowParallelLinear(LinearBase):
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
with use_symmetric_memory(
get_tp_group(), disabled=not is_allocation_symmetric()
):
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
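        # Fused path (assumes lmslim's lm_fuse_silu_mul_quant was imported above): quantize the
        # SiLU*mul activation once and hand the pre-quantized (xq, xs) pair to the quant method.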
if use_fused_silu_mul_quant:
xq, xs = lm_fuse_silu_mul_quant(input_parallel)
silu_quant_args = [xq, xs]
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel,
bias=bias_,
silu_quant_args=silu_quant_args
)
sm.tag(output_parallel)
else:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
sm.tag(output_parallel)
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
output = tensor_model_parallel_all_reduce(output_parallel)
......
......@@ -2,7 +2,6 @@ import logging
import torch
import triton
from sglang.srt.utils import ceil_div, is_cuda
logger = logging.getLogger(__name__)
......@@ -1015,196 +1014,133 @@ def zero_experts_compute_triton(
return output
from triton.language.extra import libdevice
from typing import Optional
@triton.jit
def compute_problem_sizes_w4a8_kernel(
masked_m_ptr,
problem_sizes1_ptr,
problem_sizes2_ptr,
n,
k,
num_experts,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = pid < num_experts
final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
ps1_idx_0 = pid * 3
ps1_idx_1 = ps1_idx_0 + 1
ps1_idx_2 = ps1_idx_0 + 2
ps2_idx_0 = pid * 3
ps2_idx_1 = ps2_idx_0 + 1
ps2_idx_2 = ps2_idx_0 + 2
ps1_mask_0 = ps1_idx_0 < num_experts * 3
ps1_mask_1 = ps1_idx_1 < num_experts * 3
ps1_mask_2 = ps1_idx_2 < num_experts * 3
ps2_mask_0 = ps2_idx_0 < num_experts * 3
ps2_mask_1 = ps2_idx_1 < num_experts * 3
ps2_mask_2 = ps2_idx_2 < num_experts * 3
tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
def compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
):
BLOCK_SIZE = 256
grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
compute_problem_sizes_w4a8_kernel[grid](
masked_m,
problem_sizes1,
problem_sizes2,
n,
k,
num_experts,
BLOCK_SIZE=BLOCK_SIZE,
)
return problem_sizes1, problem_sizes2
def deepep_ll_get_cutlass_w4a8_moe_mm_data(
masked_m,
problem_sizes1,
problem_sizes2,
num_experts,
n,
k,
def _per_token_quant_int8_one_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
stride_x,
stride_xq,
N,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
)
return (
problem_sizes1.to(torch.int32),
problem_sizes2.to(torch.int32),
)
row_id = tl.program_id(0)
if tokens_per_expert_ptr is not None:
e = row_id // T_dim
t = row_id % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
return
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
@triton.jit
def _silu_and_mul_post_per_tensor_quant_kernel(
input_ptr,
stride_input_expert,
stride_input_token,
stride_input_dim,
output_ptr,
stride_output_expert,
stride_output_token,
stride_output_dim,
def _per_token_quant_int8_kernel_opt(
x_ptr,
xq_ptr,
scale_ptr,
masked_m_ptr,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N: tl.constexpr,
NUM_STAGE: tl.constexpr,
stride_x,
stride_xq,
N,
E_dim,
T_dim,
tokens_per_expert_ptr,
BLOCK: tl.constexpr
):
"""
Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
Shape:
input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
output: [E, T_padded, D], dtype=float8_e4m3fn
"""
expert_id = tl.program_id(2)
block_id_token = tl.program_id(1)
block_id_dim = tl.program_id(0)
num_token_blocks = tl.num_programs(1)
token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
stride_input_expert = tl.cast(stride_input_expert, tl.int32)
stride_output_expert = tl.cast(stride_output_expert, tl.int32)
stride_input_token = tl.cast(stride_input_token, tl.int32)
stride_output_token = tl.cast(stride_output_token, tl.int32)
offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
mask_d = offset_d < inner_dim
# base pointers for current expert and dim block
input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
for token_idx in tl.range(
block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
):
gate_ptr = input_base_offs + token_idx * stride_input_token
up_ptr = gate_ptr + inner_dim
gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
# SiLU: x * sigmoid(x)
gate = gate / (1 + tl.exp(-gate))
gate = gate.to(input_ptr.dtype.element_ty)
gate_up = up * gate
scaled = gate_up * scale
output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
out_ptr = output_base_offs + token_idx * stride_output_token
tl.store(out_ptr, output_q, mask=mask_d)
def silu_and_mul_masked_post_per_tensor_quant_fwd(
input: torch.Tensor,
output: torch.Tensor,
masked_m: torch.Tensor,
scale: torch.Tensor,
) -> torch.Tensor:
"""
Fused SiLU + Mul + Per-Tensor Quantization to FP8.
Args:
input: [expert_num, token_num_padded, 2 * inner_dim]
output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
masked_m: [expert_num], actual token count for each expert
scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
Returns:
output tensor
"""
assert input.is_contiguous()
assert output.is_contiguous()
assert output.dtype == torch.float8_e4m3fn
assert input.ndim == 3
assert input.shape[0] == masked_m.shape[0]
assert input.shape[-1] % 2 == 0
assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
expert_num = input.shape[0]
# 3584
inner_dim = input.shape[-1] // 2
BLOCK_N = 256
BLOCK_M = 64 if expert_num < 4 else 32
NUM_STAGES = 3
hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_max = finfo.max
fp8_min = -fp8_max
_silu_and_mul_post_per_tensor_quant_kernel[grid](
input,
*input.stride(),
output,
*output.stride(),
scale,
masked_m,
inner_dim,
fp8_max,
fp8_min,
BLOCK_N=BLOCK_N,
NUM_STAGE=NUM_STAGES,
)
return output
token_idx_start = tl.program_id(0)
grid_size = tl.num_programs(0)
num_total_tokens = E_dim * T_dim
for token_idx in range(token_idx_start, num_total_tokens, grid_size):
is_valid_token = True
if tokens_per_expert_ptr is not None:
e = token_idx // T_dim
t = token_idx % T_dim
num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
if t >= num_valid_tokens_for_e:
is_valid_token = False
if is_valid_token:
cols = tl.arange(0, BLOCK)
mask = cols < N
x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask,
other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = x * (127 / absmax)
x_q = libdevice.nearbyint(x_q).to(tl.int8)
tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + token_idx, scale_x)
def per_token_quant_int8_triton_opt(x: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None):
if x.dim() != 3:
raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
E, T, H = x.shape
N = H
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1, ), device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
num_warps = min(max(BLOCK // 256, 1), 8)
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
num_warps = 1
num_tokens = E * T
grid_opt = num_tokens
if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
grid_opt = max(1, num_tokens // (T // 256))
_per_token_quant_int8_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
E_dim=E,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
else:
_per_token_quant_int8_one_kernel_opt[(grid_opt, )](
x,
x_q,
scales,
stride_x=x.stride(-2),
stride_xq=x_q.stride(-2),
N=N,
T_dim=T,
tokens_per_expert_ptr=tokens_per_expert,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=1,
)
return x_q, scales
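# Minimal usage sketch (shapes are illustrative): per-token symmetric int8 quantization of a
# [E, T, H] activation tensor, skipping padded rows beyond tokens_per_expert[e].
#   x_q, scales = per_token_quant_int8_triton_opt(x, tokens_per_expert)
#   # dequantize: x ~ x_q.float() * scales, with scale = row_absmax / 127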
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
from collections import defaultdict
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_marlin import SlimQuantCompressedTensorsMarlinConfig
from sglang.srt.layers.quantization.slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
import torch
import torch.distributed as dist
from sglang.srt import single_batch_overlap
from sglang.srt.layers import deep_gemm_wrapper
from sglang.srt.layers.moe import (
get_deepep_mode,
get_moe_a2a_backend,
get_moe_runner_backend,
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
silu_and_mul_masked_post_quant_fwd,
tma_align_input_scale,
per_token_quant_int8_triton_opt,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
from sglang.srt.layers.moe.token_dispatcher.deepep import (
......@@ -23,7 +33,8 @@ from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import get_bool_env_var, is_hip, is_npu
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu, direct_register_custom_op
from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (
......@@ -31,6 +42,8 @@ if TYPE_CHECKING:
DeepEPNormalDispatchOutput,
DispatchOutput,
)
from lightop import (
    m_grouped_w4a8_gemm_nt_masked,
    fuse_silu_mul_quant_ep,
    m_grouped_w8a8_gemm_nt_masked,
    m_grouped_w8a8_gemm_nt_contig_asm,
    fuse_silu_mul_quant,
)
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
_is_hip = is_hip()
_is_npu = is_npu()
......@@ -46,8 +59,400 @@ if _use_aiter:
logger = logging.getLogger(__name__)
#------ custom op for lightop
def m_grouped_w4a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w4a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w4a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def m_grouped_w8a8_gemm_nt_masked_wrapper(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return m_grouped_w8a8_gemm_nt_masked(
(a0, a1),
(b0, b1),
d,
masked_m,
expected_m_per_group,
config={"MODE": 1000,}
)
def m_grouped_w8a8_gemm_nt_masked_fake(
a0: torch.Tensor, a1: torch.Tensor,
b0: torch.Tensor, b1: torch.Tensor,
d: torch.Tensor,
masked_m: torch.Tensor,
expected_m_per_group: int
) -> torch.Tensor:
return d
def fuse_silu_mul_quant_ep_wrapper(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
return fuse_silu_mul_quant_ep(
input,
tokens_per_expert,
num_local_tokens_tensor,
topk,
expect_m
)
def fuse_silu_mul_quant_ep_fake(
input: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor] = None,
num_local_tokens_tensor: Optional[torch.Tensor] = None,
topk:int=1,
expect_m:int=-1) -> tuple[torch.Tensor, torch.Tensor]:
E, T, H = input.shape
d = H // 2
output = torch.empty(E, T, d, dtype=torch.int8, device=input.device)
scales = torch.empty((E, T, 1),
device=input.device,
dtype=torch.float32)
return output, scales
direct_register_custom_op(
op_name="m_grouped_w4a8_gemm_nt_masked",
op_func=m_grouped_w4a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w4a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="m_grouped_w8a8_gemm_nt_masked",
op_func=m_grouped_w8a8_gemm_nt_masked_wrapper,
mutates_args=[],
fake_impl=m_grouped_w8a8_gemm_nt_masked_fake
)
direct_register_custom_op(
op_name="fuse_silu_mul_quant_ep",
op_func=fuse_silu_mul_quant_ep_wrapper,
mutates_args=[],
fake_impl=fuse_silu_mul_quant_ep_fake
)
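# Registering the wrappers above via direct_register_custom_op exposes the lightop kernels as
# torch.ops.sglang.* so they can be called through the torch.ops dispatcher (e.g. under
# torch.compile); the *_fake implementations only describe output shapes and dtypes for tracing.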
#------
# TODO(kaixih@nvidia): ideally we should merge this logic into
# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@torch.compile
def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
temp = x.to(torch.float32).view(torch.int32)
exp = torch.bitwise_right_shift(temp, 23)
mant = torch.bitwise_and(temp, 0x7FFFFF)
is_ru = torch.logical_and(
torch.logical_and((mant > 0), (exp != 0xFE)),
~torch.logical_and((exp == 0), (mant <= 0x400000)),
)
exp = torch.where(is_ru, exp + 1, exp)
new_x = exp.to(torch.uint8).view(torch.int)
return new_x.transpose(1, 2).contiguous().transpose(1, 2)
class EPMoE(FusedMoE):
"""
MoE Expert Parallel Impl
"""
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
layer_id: int,
num_fused_shared_experts: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
gemm1_alpha: Optional[float] = None,
gemm1_clamp_limit: Optional[float] = None,
with_bias: bool = False,
):
super().__init__(
num_experts=num_experts,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
num_fused_shared_experts=num_fused_shared_experts,
layer_id=layer_id,
top_k=top_k,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
activation=activation,
# apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
gemm1_alpha=gemm1_alpha,
gemm1_clamp_limit=gemm1_clamp_limit,
with_bias=with_bias,
)
self.intermediate_size = intermediate_size
if isinstance(quant_config, Fp8Config):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = True
self.fp8_dtype = torch.float8_e4m3fn
self.activation_scheme = quant_config.activation_scheme
self.use_w4a8_marlin = False
self.use_w8a8_marlin = False
elif isinstance(quant_config, SlimQuantW4A8Int8MarlinConfig):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = False
self.activation_scheme = None
self.use_w4a8_marlin = True
self.use_w8a8_marlin = False
elif isinstance(quant_config, SlimQuantCompressedTensorsMarlinConfig):
self.use_block_quant = getattr(self.quant_method, "block_quant", False)
self.block_shape = (
self.quant_method.quant_config.weight_block_size
if self.use_block_quant
else None
)
self.use_fp8_w8a8 = False
self.activation_scheme = None
self.use_w4a8_marlin = False
self.use_w8a8_marlin = True
else:
self.use_fp8_w8a8 = False
self.use_block_quant = False
self.block_shape = None
self.activation_scheme = None
self.use_w4a8_marlin = False
self.use_w8a8_marlin = False
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, topk_output)
else:
return super().forward(hidden_states, topk_output)
def forward_deepgemm(
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
self.w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
self.w2_weight_fp8 = (
self.w2_weight,
self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
)
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
scale_block_size = 128
w13_weight_scale_n = 2 * (
(self.intermediate_size + scale_block_size - 1) // scale_block_size
)
w13_weight_scale_k = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w13_weight_scale = (
self.w13_weight_scale.unsqueeze(1)
.repeat_interleave(w13_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w13_weight_scale_k, dim=2)
)
self.w13_weight_fp8 = (
self.w13_weight,
w13_weight_scale,
)
w2_weight_scale_n = (
hidden_states_shape[-1] + scale_block_size - 1
) // scale_block_size
w2_weight_scale_k = (
self.intermediate_size + scale_block_size - 1
) // scale_block_size
w2_weight_scale = (
self.w2_weight_scale.unsqueeze(1)
.repeat_interleave(w2_weight_scale_n, dim=1)
.unsqueeze(2)
.repeat_interleave(w2_weight_scale_k, dim=2)
)
self.w2_weight_fp8 = (
self.w2_weight,
w2_weight_scale,
)
# PreReorder
m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
moe_ep_deepgemm_preprocess(
topk_ids,
self.num_experts,
hidden_states,
self.top_k,
self.start_expert_id,
self.end_expert_id,
self.block_shape,
)
)
dispose_tensor(hidden_states)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
b, s_mn, s_k = gateup_input_scale.shape
assert (
s_mn % 4 == 0 and s_k % 4 == 0
), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
# GroupGemm-0
gateup_input_fp8 = (
gateup_input,
(
_cast_to_e8m0_with_rounding_up(gateup_input_scale)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
gateup_input_scale
)
),
)
num_groups, m, k = gateup_input_fp8[0].size()
n = self.w13_weight.size(1)
gateup_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
gateup_input_fp8,
self.w13_weight_fp8,
gateup_output,
masked_m,
expected_m,
)
del gateup_input
del gateup_input_fp8
# Act
down_input = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2,
),
device=hidden_states_device,
dtype=self.fp8_dtype,
)
scale_block_size = 128
down_input_scale = torch.empty(
(
gateup_output.shape[0],
gateup_output.shape[1],
gateup_output.shape[2] // 2 // scale_block_size,
),
device=hidden_states_device,
dtype=torch.float32,
)
silu_and_mul_masked_post_quant_fwd(
gateup_output,
down_input,
down_input_scale,
scale_block_size,
masked_m,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del gateup_output
# GroupGemm-1
n = self.w2_weight.size(1)
down_input_fp8 = (
down_input,
(
down_input_scale
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
),
)
down_output = torch.empty(
(num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
down_input_fp8,
self.w2_weight_fp8,
down_output,
masked_m,
expected_m,
)
del down_input
del down_input_fp8
# PostReorder
output = torch.empty(
hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
)
post_reorder_triton_kernel[(hidden_states_shape[0],)](
down_output,
output,
src2dst,
topk_ids,
topk_weights,
self.start_expert_id,
self.end_expert_id,
self.top_k,
hidden_states_shape[1],
m_max * self.start_expert_id,
BLOCK_SIZE=512,
)
if self.moe_runner_config.routed_scaling_factor is not None:
output *= self.moe_runner_config.routed_scaling_factor
return output
class DeepEPMoE(FusedMoE):
class DeepEPMoE(EPMoE):
"""
MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
Mooncake EP shares the same class, as they expose the same interface.
......@@ -112,11 +517,28 @@ class DeepEPMoE(FusedMoE):
self.deepep_mode = get_deepep_mode()
if self.deepep_mode.enable_low_latency() and not _is_npu:
# NPU supports low_latency deepep without deepgemm
assert (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
# TODO: move to the beginning of the file
from sglang.srt.distributed.parallel_state import get_tp_group
from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=self.num_local_experts,
hidden_size=hidden_size,
params_dtype=params_dtype,
deepep_mode=self.deepep_mode,
async_finish=True, # TODO
return_recv_hook=True,
)
# if self.deepep_mode.enable_low_latency() and not _is_npu:
# # NPU supports low_latency deepep without deepgemm
# assert (
# deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
# ), f"DeepEP {self.deepep_mode} mode requires deep_gemm"
if _use_aiter:
# expert_mask is of size (self.num_local_experts + 1),
# the extra 1 is for invalid rank_id (in original deepep, the invalid rank_id is -1, but aiter does not allow -1, we use a mask to make those ids invalid)
......@@ -130,6 +552,23 @@ class DeepEPMoE(FusedMoE):
)
# the last one is invalid rank_id
self.expert_mask[:-1] = 1
# elif not _is_npu:
# self.w13_weight_fp8 = (
# self.w13_weight,
# (
# self.w13_weight_scale_inv
# if self.use_block_quant
# else self.w13_weight_scale
# ),
# )
# self.w2_weight_fp8 = (
# self.w2_weight,
# (
# self.w2_weight_scale_inv
# if self.use_block_quant
# else self.w2_weight_scale
# ),
# )
def forward(
self,
......@@ -189,35 +628,39 @@ class DeepEPMoE(FusedMoE):
output = self.forward_aiter(dispatch_output)
elif _is_npu:
assert DispatchOutputChecker.format_is_deepep(dispatch_output)
output = self.forward_npu(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
if self.use_w4afp8:
output = self.forward_cutlass_w4afp8(dispatch_output)
return self.forward_npu(dispatch_output)
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output):
#assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm_contiguous(dispatch_output)
elif self.use_w4a8_marlin:
return self.forward_deepgemm_w4a8_marlin_contiguous(dispatch_output)
elif self.use_w8a8_marlin:
return self.forward_groupgemm_w8a8_marlin_contiguous(dispatch_output)
else:
assert False, "forward_deepgemm_contiguous is deprecated"
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
output = self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
raise ValueError(
f"Dispatch output is not supported"
)
elif self.use_w4afp8:
output = self.forward_cutlass_w4afp8_masked(dispatch_output)
elif DispatchOutputChecker.format_is_deepep_ll(dispatch_output):
if self.use_w4a8_marlin:
return self.forward_groupgemm_w4a8_marlin_masked(dispatch_output)
elif self.use_w8a8_marlin:
return self.forward_groupgemm_w8a8_marlin_masked(dispatch_output)
else:
assert False, "forward_deepgemm_masked is deprecated"
combine_input_wrapper = (
DeepEPNormalCombineInput
if DispatchOutputChecker.format_is_deepep_normal(dispatch_output)
else DeepEPLLCombineInput
)
return combine_input_wrapper(
hidden_states=output,
topk_ids=dispatch_output.topk_ids,
topk_weights=dispatch_output.topk_weights,
)
if (
get_moe_runner_backend().is_flashinfer_cutedsl()
and self.quant_config.get_name() == "modelopt_fp4"
):
return self.forward_flashinfer_cutedsl(
dispatch_output, down_gemm_overlap_args=down_gemm_overlap_args
)
assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
assert down_gemm_overlap_args is None
return self.forward_deepgemm_masked(dispatch_output)
else:
raise ValueError(
f"Dispatch output format {dispatch_output.format} is not supported"
)
def combine(
self,
......@@ -267,6 +710,292 @@ class DeepEPMoE(FusedMoE):
expert_mask=self.expert_mask,
)
def forward_deepgemm_w4a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = (
dispatch_output
)
#hidden_states_int8, hidden_states_scale = hidden_states_int8
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
expert_output = self.quant_method.apply_ep(
x=hidden_states,
w1=self.w13_weight,
w2=self.w2_weight,
topk_ids=topk_idx,
topk_weights=topk_weights,
global_num_experts=self.moe_runner_config.num_experts,
expert_map=self.expert_map,
activation=self.moe_runner_config.activation,
apply_router_weight_on_input=self.moe_runner_config.apply_router_weight_on_input,
use_nn_moe=False,
w1_scale=self.w13_weight_scale,
w2_scale=self.w2_weight_scale,
routed_scaling_factor = self.moe_runner_config.routed_scaling_factor,
)
return expert_output
def forward_groupgemm_w8a8_marlin_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
hidden_states, hidden_states_scale, topk_idx, topk_weights, num_recv_tokens_per_expert = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
device = hidden_states.device
M = hidden_states.shape[0]
K = hidden_states.shape[1]
topk = topk_idx.shape[1]
active_experts = set()
token_expert_pos = [None] * M
for t in range(M):
lst = []
for pos in range(topk):
e = int(topk_idx[t, pos].item())
if e >= 0:
lst.append((e, pos))
active_experts.add(e)
token_expert_pos[t] = lst
active_experts = sorted(list(active_experts))
num_active = len(active_experts)
if num_active == 0:
return hidden_states.bfloat16()
counts = defaultdict(int)
for t in range(M):
for (e, pos) in token_expert_pos[t]:
counts[e] += 1
per_expert_block = {}
for e in active_experts:
cnt = counts.get(e, 0)
if cnt <= 0:
per_expert_block[e] = 0
else:
needed = ((cnt + 256 - 1) // 256) * 256 # next multiple of 256
per_expert_block[e] = max(256, needed)
expert_slot_offset = {}
offset = 0
for e in active_experts:
expert_slot_offset[e] = offset
offset += per_expert_block[e]
pad_M = offset
hidden_states_packed = torch.zeros((pad_M, K), device=device, dtype=hidden_states.dtype)
m_indices = torch.full((pad_M,), -1, device=device, dtype=torch.int32)
slot_counters = {e: 0 for e in active_experts}
token_row_weight_list = {t: [] for t in range(M)}
for t in range(M):
for (e, pos) in token_expert_pos[t]:
start = expert_slot_offset[e]
slot = slot_counters[e]
if slot >= per_expert_block[e]:
raise RuntimeError(f"Internal error: expert {e} slot {slot} >= block {per_expert_block[e]}")
row = start + slot
hidden_states_packed[row] = hidden_states[t]
m_indices[row] = int(e)
slot_counters[e] += 1
w = topk_weights[t, pos].to(device=device)
w_f = w.float() if w.dtype != torch.float32 else w
token_row_weight_list[t].append((row, w_f))
q_a1_all, q_a1_scale = per_token_quant_int8(hidden_states_packed)
N = self.w13_weight.size(1)
gateup_output = torch.empty((pad_M, N * 16), device=device, dtype=torch.bfloat16)
m_grouped_w8a8_gemm_nt_contig_asm(
(q_a1_all, q_a1_scale),
(self.w13_weight, self.w13_weight_scale),
gateup_output,
m_indices,
)
q_a2_all, q_a2_scale = fuse_silu_mul_quant(gateup_output)
down_output = torch.empty((pad_M, K), device=device, dtype=torch.bfloat16)
down_output = m_grouped_w8a8_gemm_nt_contig_asm(
(q_a2_all, q_a2_scale),
(self.w2_weight, self.w2_weight_scale),
down_output,
m_indices,
)
result = torch.zeros((M, K), device=device, dtype=down_output.dtype)
for t in range(M):
pairs = token_row_weight_list[t]
if not pairs:
continue
acc = None
for (row, w) in pairs:
vec = down_output[row].float()
weighted = vec * w
acc = weighted if acc is None else (acc + weighted)
result[t] = acc.to(result.dtype)
return result
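    # NOTE: this contiguous w8a8 path packs tokens per expert with Python loops and per-element
    # .item() calls, which synchronize with the device; it trades throughput for simplicity.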
def forward_deepgemm_contiguous(
self,
dispatch_output: DeepEPNormalOutput,
):
(
hidden_states,
hidden_states_scale,
topk_ids,
topk_weights,
num_recv_tokens_per_expert,
) = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
if num_recv_tokens_per_expert is None:
return hidden_states.bfloat16()
all_tokens = sum(num_recv_tokens_per_expert)
if all_tokens <= 0:
return hidden_states.bfloat16()
M, K = hidden_states.size()
N = self.w13_weight.size(1)
scale_block_size = 128
w13_weight_fp8 = (
self.w13_weight,
(
self.w13_weight_scale_inv
if self.use_block_quant
else self.w13_weight_scale
),
)
w2_weight_fp8 = (
self.w2_weight,
(
self.w2_weight_scale_inv
if self.use_block_quant
else self.w2_weight_scale
),
)
hidden_states_shape = hidden_states.shape
hidden_states_device = hidden_states.device
hidden_states_dtype = hidden_states.dtype
input_tensor = [
torch.empty(
(all_tokens, K),
device=hidden_states.device,
dtype=hidden_states.dtype,
),
(
# TODO check whether need `zeros`
torch.zeros(
(ceil_div(K // 128, 4), all_tokens),
device=hidden_states.device,
dtype=torch.int,
).transpose(0, 1)
if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
else torch.empty(
(all_tokens, K // 128),
device=hidden_states.device,
dtype=torch.float32,
)
),
]
m_indices = torch.empty(
all_tokens, device=hidden_states.device, dtype=torch.int32
)
output_index = torch.empty_like(topk_ids)
if get_offloader().forbid_copy_engine_usage:
num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
num_recv_tokens_per_expert
)
else:
num_recv_tokens_per_expert_gpu = torch.tensor(
num_recv_tokens_per_expert,
dtype=torch.int32,
pin_memory=True,
device="cpu",
).cuda(non_blocking=True)
expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
ep_scatter(
hidden_states,
hidden_states_scale,
topk_ids,
num_recv_tokens_per_expert_gpu,
expert_start_loc,
input_tensor[0],
input_tensor[1],
m_indices,
output_index,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
dispose_tensor(hidden_states)
gateup_output = torch.empty(
(all_tokens, N),
device=hidden_states_device,
dtype=torch.bfloat16,
)
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
input_tensor[1] = tma_align_input_scale(input_tensor[1])
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
input_tensor, w13_weight_fp8, gateup_output, m_indices
)
del input_tensor
down_input = torch.empty(
(
all_tokens,
N // 2,
),
device=gateup_output.device,
dtype=torch.bfloat16,
)
silu_and_mul(gateup_output.view(-1, N), down_input)
del gateup_output
down_output = torch.empty(
(all_tokens, K),
device=hidden_states_device,
dtype=torch.bfloat16,
)
down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
down_input,
scale_block_size,
column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
)
del down_input
if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
down_input_scale = tma_align_input_scale(down_input_scale)
deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
(down_input_fp8, down_input_scale),
w2_weight_fp8,
down_output,
m_indices,
)
del down_input_fp8, down_input_scale
gather_out = torch.empty(
hidden_states_shape,
device=hidden_states_device,
dtype=torch.bfloat16,
)
ep_gather(down_output, topk_ids, topk_weights, output_index, gather_out)
return gather_out
def forward_flashinfer_cutedsl(
self,
dispatch_output: DeepEPLLDispatchOutput,
......@@ -296,7 +1025,106 @@ class DeepEPMoE(FusedMoE):
dispatch_output=dispatch_output,
)
def forward_cutlass_w4afp8_masked(
def forward_groupgemm_w4a8_marlin_masked(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, _, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m)
        # ---- first quant: per-token int8 quantization of the dispatched activations ----
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
w13_scales = self.w13_weight_scale
w2_weight = self.w2_weight
w2_scales = self.w2_weight_scale
n1 = w13_scales.size(1)
gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
torch.ops.sglang.m_grouped_w4a8_gemm_nt_masked(
q_a2_all, q_a2_scale,
w2_weight, w2_scales,
down_output,
masked_m,
expected_m,
)
return down_output
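    # Pipeline sketch: per-token int8 quant -> grouped w4a8 GEMM (gate/up) -> fused SiLU*mul +
    # requant -> grouped w4a8 GEMM (down); all three kernels are the lightop custom ops
    # registered near the top of this file.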
def forward_groupgemm_w8a8_marlin_masked(
self,
dispatch_output: DeepEPLLOutput,
):
hidden_states, _, topk_ids, _, masked_m, expected_m = dispatch_output
assert self.quant_method is not None
assert self.moe_runner_config.activation == "silu"
# base shapes
num_groups, m, k = hidden_states.size()
expected_m = min(m, expected_m)
        # ---- first quant: per-token int8 quantization of the dispatched activations ----
q_a1_all, q_a1_scale = per_token_quant_int8_triton_opt(hidden_states, masked_m)
# ---- weights & scales ----
w13_weight = self.w13_weight
w13_scales = self.w13_weight_scale
w2_weight = self.w2_weight
w2_scales = self.w2_weight_scale
n1 = w13_scales.size(1)
gateup_output = torch.empty((num_groups, m, n1), device=hidden_states.device, dtype=torch.bfloat16)
# ---- first GEMM ----
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
q_a1_all, q_a1_scale,
w13_weight, w13_scales,
gateup_output,
masked_m,
expected_m,
)
q_a2_all, q_a2_scale = torch.ops.sglang.fuse_silu_mul_quant_ep(gateup_output, masked_m)
# ---- second GEMM ----
n2 = w2_scales.size(1)
down_output = torch.empty((num_groups, m, n2), device=q_a2_all.device, dtype=torch.bfloat16)
torch.ops.sglang.m_grouped_w8a8_gemm_nt_masked(
q_a2_all, q_a2_scale,
w2_weight, w2_scales,
down_output,
masked_m,
expected_m,
)
return down_output
def forward_deepgemm_masked(
self,
dispatch_output: DeepEPLLDispatchOutput,
):
......
......@@ -65,7 +65,7 @@ def inplace_fused_experts(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
    activation: int = 0,  # 0 = silu, 1 = gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -84,6 +84,8 @@ def inplace_fused_experts(
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> None:
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
fused_experts_impl(
hidden_states,
w1,
......@@ -123,7 +125,7 @@ def inplace_fused_experts_fake(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
    activation: int = 0,  # 0 = silu, 1 = gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -161,7 +163,7 @@ def outplace_fused_experts(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
    activation: int = 0,  # 0 = silu, 1 = gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -181,6 +183,8 @@ def outplace_fused_experts(
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
) -> torch.Tensor:
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
return fused_experts_impl(
hidden_states,
w1,
......@@ -220,7 +224,7 @@ def outplace_fused_experts_fake(
topk_ids: torch.Tensor,
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
activation: str = "silu",
    activation: int = 0,  # 0 = silu, 1 = gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -273,9 +277,12 @@ def fused_experts(
block_shape: Optional[List[int]] = None,
):
topk_weights, topk_ids, _ = topk_output
filter_expert = (
moe_runner_config.num_experts is None
or moe_runner_config.num_experts != moe_runner_config.num_local_experts
act_id = (
0 if (
moe_runner_config.activation == 0
or (isinstance(moe_runner_config.activation, str)
and moe_runner_config.activation.lower() == "silu")
) else 1
)
if moe_runner_config.inplace:
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
......@@ -287,7 +294,7 @@ def fused_experts(
topk_ids,
b1,
b2,
moe_runner_config.activation,
act_id,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
......@@ -316,7 +323,7 @@ def fused_experts(
topk_ids,
b1,
b2,
moe_runner_config.activation,
act_id,
moe_runner_config.apply_router_weight_on_input,
use_fp8_w8a8,
use_int8_w8a8,
......@@ -366,7 +373,7 @@ def fused_experts_impl(
b1: Optional[torch.Tensor] = None,
b2: Optional[torch.Tensor] = None,
inplace: bool = False,
activation: str = "silu",
    activation: int = 0,  # 0 = silu, 1 = gelu
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
......@@ -386,6 +393,8 @@ def fused_experts_impl(
gemm1_limit: Optional[float] = None,
filter_expert: bool = True,
):
if isinstance(activation, int):
activation = "silu" if activation == 0 else "gelu"
padded_size = padding_size
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
padded_size = 0
......