Commit a1175a4e authored by maxiao1

Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

parents 0c006b88 31653dd9
@@ -848,10 +848,12 @@ class BenchmarkMetrics:
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    p95_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    std_tpot_ms: float
    p95_tpot_ms: float
    p99_tpot_ms: float
    mean_itl_ms: float
    median_itl_ms: float
@@ -1721,10 +1723,12 @@ def calculate_metrics(
        * 1000,  # ttfts is empty if streaming is not supported by backend
        median_ttft_ms=np.median(ttfts or 0) * 1000,
        std_ttft_ms=np.std(ttfts or 0) * 1000,
        p95_ttft_ms=np.percentile(ttfts or 0, 95) * 1000,
        p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000,
        mean_tpot_ms=np.mean(tpots or 0) * 1000,
        median_tpot_ms=np.median(tpots or 0) * 1000,
        std_tpot_ms=np.std(tpots or 0) * 1000,
        p95_tpot_ms=np.percentile(tpots or 0, 95) * 1000,
        p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000,
        mean_itl_ms=np.mean(itls or 0) * 1000,
        median_itl_ms=np.median(itls or 0) * 1000,
@@ -2052,6 +2056,12 @@ async def benchmark(
    print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms))
    print("{:<40} {:<10.2f}".format("Median TTFT (ms):", metrics.median_ttft_ms))
    print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms))
    print("{:<40} {:<10.2f}".format("P95 TTFT (ms):", metrics.p95_ttft_ms))
    print("{s:{c}^{n}}".format(s="Time per Output Token (excl. 1st token)", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms))
    print("{:<40} {:<10.2f}".format("Median TPOT (ms):", metrics.median_tpot_ms))
    print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms))
    print("{:<40} {:<10.2f}".format("P95 TPOT (ms):", metrics.p95_tpot_ms))
    print("{s:{c}^{n}}".format(s="Inter-Token Latency", n=50, c="-"))
    print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms))
    print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms))
......
@@ -4,10 +4,24 @@ from typing import List, Optional, Tuple

import torch

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

try:
    from lmslim import quant_ops
    from lmslim import quant_tools
except Exception:
    print("INFO: Please install lmslim if you want to run inference with GPTQ, AWQ, or W8A8 models.\n")

try:
    import lightop
except Exception:
    print("INFO: Please install lightop if you want to run AWQ inference with Marlin kernels.\n")
logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
"USE_VLLM_CUSTOM_ALLREDUCE", default="false"
)
use_dcu_custom_allreduce = get_bool_env_var(
"USE_DCU_CUSTOM_ALLREDUCE", default="true"
)
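# Hypothetical helper (not part of the patch) summarizing the dispatch that
# follows: on non-HIP/NPU devices the sgl_kernel binding is used; on a HIP/DCU
# device the vllm._C / torch.ops._C_custom_ar binding is used while
# USE_DCU_CUSTOM_ALLREDUCE stays at its default of "true".
def _selected_allreduce_binding(on_hip_device: bool, on_npu_device: bool = False) -> str:
    if not on_hip_device and not on_npu_device:
        return "sgl_kernel.allreduce"
    if on_hip_device and use_dcu_custom_allreduce:
        return "torch.ops._C_custom_ar (vllm._C)"
    return "sgl_kernel ROCm custom allreduce"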
if not is_hpu():
    try:
@@ -15,6 +29,11 @@ if not is_hpu():
    except ImportError as e:
        logger.warning("Failed to import from custom_ar with %r", e)

    if use_dcu_custom_allreduce:
        try:
            import vllm._C
        except ImportError as e:
            logger.warning("Failed to import from vllm._C with %r", e)
if not is_hip() and not is_npu():
    custom_op = sgl_kernel.allreduce
@@ -54,8 +73,79 @@ if not is_hip() and not is_npu():
    ) -> None:
        custom_op.register_graph_buffers(fa, handles, offsets)
elif is_hip() and use_dcu_custom_allreduce:
    # custom ar
    def init_custom_ar(ipc_tensors: list[torch.Tensor], rank_data: torch.Tensor,
                       rank: int, fully_connected: bool) -> int:
        return torch.ops._C_custom_ar.init_custom_ar(ipc_tensors, rank_data, rank,
                                                     fully_connected)

    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int,
                   reg_buffer_sz_bytes: int) -> None:
        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer,
                                          reg_buffer_sz_bytes)

    def dispose(fa: int) -> None:
        torch.ops._C_custom_ar.dispose(fa)

    def meta_size() -> int:
        return torch.ops._C_custom_ar.meta_size()

    def register_buffer(fa: int, ipc_tensors: list[int]) -> None:
        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)

    def get_graph_buffer_ipc_meta(fa: int) -> tuple[list[int], list[int]]:
        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)

    def register_graph_buffers(fa: int, handles: list[list[int]],
                               offsets: list[list[int]]) -> None:
        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)

    def allocate_shared_buffer_and_handle(size: int) -> tuple[int, torch.Tensor]:
        return torch.ops._C_custom_ar.allocate_shared_buffer_and_handle(size)

    def open_mem_handle(mem_handle: torch.Tensor):
        return torch.ops._C_custom_ar.open_mem_handle(mem_handle)

    def free_shared_buffer(ptr: int) -> None:
        torch.ops._C_custom_ar.free_shared_buffer(ptr)

    def read_cache(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.read_cache(keys, values, key_caches,
                                          value_caches, slot_mapping,
                                          kv_cache_dtype)

    def write_cache_multi_layers(
        keys: torch.Tensor,
        values: torch.Tensor,
        key_caches: list[torch.Tensor],
        value_caches: list[torch.Tensor],
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
    ) -> None:
        torch.ops._C_cache_ops.write_cache_multi_layers(keys, values, key_caches,
                                                        value_caches, slot_mapping,
                                                        kv_cache_dtype)
else:
    # sgl_kernel ROCM custom allreduce
    def init_custom_ar(
        meta: torch.Tensor,
@@ -163,3 +253,83 @@ def mscclpp_allreduce(
    context: int, inp: torch.Tensor, out: torch.Tensor, nthreads: int, nblocks: int
) -> None:
    return sgl_kernel.allreduce.mscclpp_allreduce(context, inp, out, nthreads, nblocks)
def triton_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None,
                     best_config: Optional[list] = None) -> torch.Tensor:
    return quant_ops.triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias, best_config)
def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """
    `cutlass_scaled_mm` implements a fused version of
    `output = torch.mm((scale_a * a), (scale_b * b)).to(out_dtype)`
    where scale_a * a and scale_b * b are implemented using numpy-style
    broadcasting.

    In order to support blockwise scaling as found in DeepSeek V3, we also
    support extended "group" broadcast rules. We extend the numpy-style
    broadcasting rules with the following rule:
        "if the extent of a dimension in the source shape is between 1 and the
        corresponding extent in the target shape, we repeat each element along
        that dimension target_shape[dim] // src_shape[dim] times consecutively"

    For example, if we have:
        a = [[1, 2],    and target_shape = (2, 4)
             [3, 4]]
    then we would expand a to:
        a = [[1, 1, 2, 2],
             [3, 3, 4, 4]]

    Currently we only support the case:
        scale_a.shape * [1, 128] == a.shape
        scale_b.shape * [128, 128] == b.shape
    """
    assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
    assert bias is None or bias.shape[0] == b.shape[1] and bias.dtype == out_dtype
    # m = a.shape[0]
    # n = b.shape[1]
    # cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
    # if current_platform.is_rocm() or not cutlass_compatible_b:
    #     from vllm.model_executor.layers.quantization.compressed_tensors.triton_scaled_mm import (  # noqa
    #         triton_scaled_mm)
    #     return triton_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    # out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    # torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    # return out
    # return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def rocblas_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def triton_int8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            use_bias: bool,
                            out_dtype: type[torch.dtype] = torch.float16,
                            device: str = "cuda:0",
                            best_config: Optional[list] = None,
                            repeat: Optional[int] = 2):
    return quant_tools.triton_int8_gemm_helper(m, n, k, per_token_act_quant,
                                               per_out_channel_weight_quant, use_bias,
                                               out_dtype, device, best_config, repeat)
\ No newline at end of file
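# Reference semantics for the scaled-mm wrappers above (a minimal PyTorch sketch,
# not the lmslim kernels; the function name is hypothetical): they compute
# out = (scale_a * a) @ (scale_b * b) cast to out_dtype, with per-token /
# per-channel scales broadcast numpy-style as described in the docstring.
from typing import Optional

import torch

def reference_scaled_mm(a: torch.Tensor,
                        b: torch.Tensor,
                        scale_a: torch.Tensor,   # e.g. (M, 1) per-token scales
                        scale_b: torch.Tensor,   # e.g. (1, N) per-channel scales
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = torch.mm(scale_a * a.to(torch.float32), scale_b * b.to(torch.float32))
    if bias is not None:
        out = out + bias
    return out.to(out_dtype)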
@@ -635,7 +635,9 @@ class ModelConfig:
            "petit_nvfp4",
            "quark",
            "mxfp4",
            "slimquant_w4a8_marlin",
            "w8a8_int8",
            "slimquant_marlin",
        ]
        optimized_quantization_methods = [
            "fp8",
@@ -655,6 +657,8 @@ class ModelConfig:
            "qoq",
            "w4afp8",
            "petit_nvfp4",
            "slimquant_w4a8_marlin",
            "slimquant_marlin",
        ]
        compatible_quantization_methods = {
            "modelopt_fp8": ["modelopt"],
......
@@ -34,6 +34,21 @@ except ImportError:

_is_cuda = is_cuda()
_is_hip = is_hip()

try:
    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
    elif ops.use_dcu_custom_allreduce:
        ops.meta_size()
    else:
        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
        import sgl_kernel  # noqa: F401
    custom_ar = True
except Exception:
    # For CPUs
    custom_ar = False

logger = logging.getLogger(__name__)
@@ -416,3 +431,274 @@ class CustomAllreduce:
    def __del__(self):
        self.close()
class DCUCustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8, 16]
# max_size: max supported allreduce size
def __init__(self,
group: ProcessGroup,
device: Union[int, str, torch.device],
max_size=8192 * 512) -> None:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bound to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bound to a unique device, and all communicators in this group
are in the same node.
"""
self._IS_CAPTURING = False
self.disabled = True
if not custom_ar:
# disable because of missing custom allreduce library
# e.g. in a non-GPU environment
logger.info("Custom allreduce is disabled because "
"of missing custom allreduce library")
return
self.group = group
assert dist.get_backend(group) != dist.Backend.NCCL, (
"CustomAllreduce should be attached to a non-NCCL group.")
if not all(in_the_same_node_as(group, source_rank=0)):
# No need to initialize custom allreduce for multi-node case.
logger.warning(
"Custom allreduce is disabled because this process group"
" spans across nodes.")
return
rank = dist.get_rank(group=self.group)
self.rank = rank
world_size = dist.get_world_size(group=self.group)
# if world_size > envs.VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX:
if world_size > 16:
return
if world_size == 1:
# No need to initialize custom allreduce for single GPU case.
return
if world_size not in DCUCustomAllreduce._SUPPORTED_WORLD_SIZES:
logger.warning(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.",
world_size, str(DCUCustomAllreduce._SUPPORTED_WORLD_SIZES))
return
if isinstance(device, int):
device = torch.device(f"cuda:{device}")
elif isinstance(device, str):
device = torch.device(device)
# now `device` is a `torch.device` object
assert isinstance(device, torch.device)
self.device = device
cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if cuda_visible_devices:
device_ids = list(map(int, cuda_visible_devices.split(",")))
else:
device_ids = list(range(torch.cuda.device_count()))
physical_device_id = device_ids[device.index]
tensor = torch.tensor([physical_device_id],
dtype=torch.int,
device="cpu")
gather_list = [
torch.tensor([0], dtype=torch.int, device="cpu")
for _ in range(world_size)
]
dist.all_gather(gather_list, tensor, group=self.group)
physical_device_ids = [t.item() for t in gather_list]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
# assert current_platform.is_cuda_alike()
# fully_connected = current_platform.is_fully_connected(
# physical_device_ids)
if _is_cuda or _is_hip:
fully_connected = is_full_nvlink(physical_device_ids, world_size)
# if world_size > 2 and not fully_connected:
if not fully_connected:
max_size = 32 * 8192 * 2
# if not envs.VLLM_PCIE_USE_CUSTOM_ALLREDUCE:
# logger.warning(
# "Custom allreduce is disabled because it's not supported on"
# " more than two PCIe-only GPUs. To silence this warning, "
# "specify disable_custom_all_reduce=True explicitly.")
# return
logger.warning(
"Using custom allreduce over PCIe. If performance is poor, "
"pass --disable-custom-all-reduce to disable it.")
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not _is_hip and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly.")
return
self.disabled = False
# Buffer memory is owned by this Python class and passed to C++.
# The metadata consists of two parts: metadata for synchronization and a
# temporary buffer for storing intermediate allreduce results.
self.meta_ptrs = self.create_shared_buffer(ops.meta_size() + max_size,
group=group,
uncached=True)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# is enough for 131072 such tuples. The largest model I've seen only
# needs less than 10000 of registered tuples.
self.rank_data = torch.empty(8 * 1024 * 1024,
dtype=torch.uint8,
device=self.device)
self.max_size = max_size
self.rank = rank
self.world_size = world_size
self.fully_connected = fully_connected
self._ptr = ops.init_custom_ar(self.meta_ptrs, self.rank_data, rank,
self.fully_connected)
ops.register_buffer(self._ptr, self.buffer_ptrs)
@contextmanager
def capture(self):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try:
self._IS_CAPTURING = True
yield
finally:
self._IS_CAPTURING = False
if not self.disabled:
self.register_graph_buffers()
def register_graph_buffers(self):
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
logger.info("Registering %d cuda graph addresses", len(offset))
# We cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
all_data = [[None, None]
for _ in range(dist.get_world_size(group=self.group))]
all_data[self.rank] = [handle, offset]
ranks = sorted(dist.get_process_group_ranks(group=self.group))
for i, rank in enumerate(ranks):
dist.broadcast_object_list(all_data[i],
src=rank,
group=self.group,
device="cpu")
# Unpack list of tuples to tuple of lists.
handles = [d[0] for d in all_data] # type: ignore
offsets = [d[1] for d in all_data] # type: ignore
ops.register_graph_buffers(self._ptr, handles, offsets)
def should_custom_ar(self, inp: torch.Tensor):
if self.disabled:
return False
inp_size = inp.numel() * inp.element_size()
# custom allreduce requires input byte size to be multiples of 16
if inp_size % 16 != 0:
return False
if not is_weak_contiguous(inp):
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
return inp_size <= self.max_size
def all_reduce(self,
inp: torch.Tensor,
*,
out: torch.Tensor = None,
registered: bool = False):
"""Performs an out-of-place all reduce.
If registered is True, this assumes inp's pointer is already
IPC-registered. Otherwise, inp is first copied into a pre-registered
buffer.
"""
if out is None:
out = torch.empty_like(inp)
if registered:
ops.all_reduce(self._ptr, inp, out, 0, 0)
else:
ops.all_reduce(self._ptr, inp, out, self.buffer_ptrs[self.rank],
self.max_size)
return out
def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
"""The main allreduce API that provides support for cuda graph."""
# When custom allreduce is disabled, this will be None.
if self.disabled or not self.should_custom_ar(input):
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
return self.all_reduce(input, registered=False)
else:
# If warm up, mimic the allocation pattern since custom
# allreduce is out-of-place.
return torch.empty_like(input)
else:
# Note: outside of cuda graph context, custom allreduce incurs a
# cost of cudaMemcpy, which should be small (<=1% of overall
# latency) compared to the performance gain of using custom kernels
return self.all_reduce(input, registered=False)
def close(self):
if not self.disabled and self._ptr:
if ops is not None:
ops.dispose(self._ptr)
self._ptr = 0
self.free_shared_buffer(self.meta_ptrs, rank=self.rank)
self.free_shared_buffer(self.buffer_ptrs, rank=self.rank)
def __del__(self):
self.close()
@staticmethod
def create_shared_buffer(size_in_bytes: int,
group: Optional[ProcessGroup] = None,
uncached: Optional[bool] = False) -> list[int]:
pointer, handle = ops.allocate_shared_buffer_and_handle(size_in_bytes)
world_size = dist.get_world_size(group=group)
rank = dist.get_rank(group=group)
handles = [None] * world_size
dist.all_gather_object(handles, handle, group=group)
pointers: list[int] = []
for i, h in enumerate(handles):
if i == rank:
pointers.append(pointer) # type: ignore
else:
pointers.append(ops.open_mem_handle(h))
return pointers
@staticmethod
def free_shared_buffer(pointers: list[int],
group: Optional[ProcessGroup] = None,
rank: Optional[int] = 0) -> None:
if rank is None:
rank = dist.get_rank(group=group)
if ops is not None:
ops.free_shared_buffer(pointers[rank])
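# Minimal usage sketch (hypothetical helper; GroupCoordinator in parallel_state.py
# below wires this up for real): try the custom path first and fall back to the
# regular dist.all_reduce whenever custom_all_reduce() declines the input and
# returns None.
def allreduce_with_fallback(comm: "DCUCustomAllreduce",
                            tensor: torch.Tensor,
                            group: ProcessGroup) -> torch.Tensor:
    out = comm.custom_all_reduce(tensor)
    if out is not None:
        return out
    dist.all_reduce(tensor, group=group)  # e.g. the NCCL device group
    return tensor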
@@ -54,6 +54,7 @@ from sglang.srt.utils import (
    is_xpu,
    supports_custom_op,
)
from sglang.srt import _custom_ops as ops

_is_npu = is_npu()
_is_cpu = is_cpu()
@@ -327,7 +328,7 @@ class GroupCoordinator:
        # Lazy import to avoid documentation build error
        from sglang.srt.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce, DCUCustomAllreduce,
        )
        from sglang.srt.distributed.device_communicators.pymscclpp import (
            PyMscclppCommunicator,
@@ -371,10 +372,17 @@ class GroupCoordinator:
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            try:
                if is_hip() and ops.use_dcu_custom_allreduce:
                    self.ca_comm = DCUCustomAllreduce(
                        group=self.cpu_group,
                        device=self.device,
                    )
                else:
                    self.ca_comm = CustomAllreduce(
                        group=self.cpu_group,
                        device=self.device,
                        max_size=ca_max_size,
                    )
            except Exception as e:
                logger.warning(
                    f"Setup Custom allreduce failed with {e}. To silence this "
......
@@ -188,6 +188,17 @@ class Envs:
    SGLANG_USE_AITER = EnvBool(False)
    SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)

    # DCU Lightop
    SGLANG_USE_LIGHTOP = EnvBool(False)

    # Fused
    SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD = EnvBool(False)
    SGLANG_USE_OPT_CAT = EnvBool(False)
    SGLANG_USE_FUSED_RMS_QUANT = EnvBool(False)
    SGLANG_USE_FUSED_SILU_MUL_QUANT = EnvBool(False)

    # Quantization
    SGLANG_INT4_WEIGHT = EnvBool(False)
......
@@ -99,7 +99,6 @@ def create_triton_backend(runner):
    return TritonAttnBackend(runner)


@register_attention_backend("torch_native")
def create_torch_native_backend(runner):
    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -120,6 +119,11 @@ def create_flashmla_backend(runner):
    return FlashMLABackend(runner)


@register_attention_backend("dcu_mla")
def create_dcu_mla_backend(runner):
    from sglang.srt.layers.attention.dcu_mla_backend import DCUMLABackend

    return DCUMLABackend(runner)


@register_attention_backend("fa3")
def create_flashattention_v3_backend(runner):
......
@@ -9,7 +9,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import torch
import torch.nn.functional as F

# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache
from sgl_kernel.sparse_flash_attn import (
    convert_vertical_slash_indexes,
    convert_vertical_slash_indexes_mergehead,
......
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


@dataclass
@@ -328,6 +329,8 @@ class FlashAttentionBackend(AttentionBackend):
        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
        self.skip_prefill = skip_prefill
        self.is_hybrid = model_runner.is_hybrid
        self.k_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        self.v_scale = torch.tensor([1.0], dtype=torch.float32, device=self.device)
        if self.is_hybrid:
            self.full_to_swa_index_mapping = (
                model_runner.token_to_kv_pool.full_to_swa_index_mapping
@@ -596,9 +599,11 @@ class FlashAttentionBackend(AttentionBackend):
                forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]

            if (
                any(forward_batch.extend_prefix_lens_cpu)
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND_V2  # nhb
            ):
                extend_seq_lens = forward_batch.extend_seq_lens
                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -608,10 +613,13 @@ class FlashAttentionBackend(AttentionBackend):
                metadata.max_seq_len_q = metadata.max_seq_len_k
                metadata.cu_seqlens_q = metadata.cu_seqlens_k

            # # Setup local attention if enabled
            # if forward_batch.forward_mode == ForwardMode.EXTEND:
            #     self._init_local_attn_metadata(forward_batch, metadata, device)
            if forward_batch.forward_mode in (ForwardMode.EXTEND, ForwardMode.DRAFT_EXTEND_V2):
                self._init_local_attn_metadata(forward_batch, metadata, device)

        # Encoder metadata for cross attention
        if forward_batch.encoder_lens is not None:
            assert (
@@ -668,10 +676,11 @@ class FlashAttentionBackend(AttentionBackend):
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        if k_rope is None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, cache_loc, k, v,  # layer.k_scale, layer.v_scale
            )
        else:
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                layer,
@@ -690,7 +699,8 @@ class FlashAttentionBackend(AttentionBackend):
            layer.sliding_window_size is not None and layer.sliding_window_size > -1
        )
        window_size = (layer.sliding_window_size, 0) if is_swa else (-1, -1)

        # k_descale, v_descale = None, None
        k_descale, v_descale = self.k_scale, self.v_scale
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
@@ -704,7 +714,7 @@ class FlashAttentionBackend(AttentionBackend):
            descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
            k_descale = layer.k_scale.expand(descale_shape)
            v_descale = layer.v_scale.expand(descale_shape)
        # q = q.to(self.kv_cache_dtype)
        q_rope = q_rope.to(self.kv_cache_dtype) if q_rope is not None else None
        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        causal = True
@@ -774,61 +784,59 @@ class FlashAttentionBackend(AttentionBackend):
                cu_seqlens_k = metadata.encoder_cu_seqlens_k
                window_size = (-1, -1)

            if forward_batch.attn_attend_prefix_cache:
                assert not get_global_server_args().disable_chunked_prefix_cache
                # MHA for chunked prefix kv cache when running model with MLA
                assert forward_batch.prefix_chunk_idx is not None
                assert forward_batch.prefix_chunk_cu_seq_lens is not None
                assert forward_batch.prefix_chunk_max_seq_lens is not None

                chunk_idx = forward_batch.prefix_chunk_idx
                assert chunk_idx >= 0

                assert forward_batch.mha_return_lse
                output = flash_attn_varlen_func(
                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                    max_seqlen_q=metadata.max_seq_len_q,
                    max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                    softmax_scale=layer.scaling,
                    causal=False,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=True,
                    **kwargs,
                )
            else:
                output = flash_attn_varlen_func(
                    q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                    k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
                    v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                    cu_seqlens_q=metadata.cu_seqlens_q,
                    cu_seqlens_k=metadata.cu_seqlens_q,
                    max_seqlen_q=metadata.max_seq_len_q,
                    max_seqlen_k=metadata.max_seq_len_q,
                    softmax_scale=layer.scaling,
                    causal=True,
                    k_descale=k_descale,
                    v_descale=v_descale,
                    return_softmax_lse=forward_batch.mha_return_lse,
                    **kwargs,
                )
            if forward_batch.mha_return_lse:
                output, lse, *rest = output
                lse = torch.transpose(lse, 0, 1).contiguous()
                return output, lse
            return output.view(-1, layer.tp_q_head_num * layer.v_head_dim)
        else:
            if (
                forward_batch.attn_attend_prefix_cache is not None
                and not forward_batch.forward_mode.is_target_verify()
                and not forward_batch.forward_mode.is_draft_extend()
            ):
                if forward_batch.attn_attend_prefix_cache:
                    assert not get_global_server_args().disable_chunked_prefix_cache
@@ -843,39 +851,32 @@ class FlashAttentionBackend(AttentionBackend):
                    assert forward_batch.mha_return_lse
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=forward_batch.prefix_chunk_cu_seq_lens[chunk_idx],
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=forward_batch.prefix_chunk_max_seq_lens[chunk_idx],
                        softmax_scale=layer.scaling,
                        causal=False,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=True,
                        **kwargs,
                    )
                else:
                    output = flash_attn_varlen_func(
                        q=q.view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                        cu_seqlens_q=metadata.cu_seqlens_q,
                        cu_seqlens_k=metadata.cu_seqlens_q,
                        max_seqlen_q=metadata.max_seq_len_q,
                        max_seqlen_k=metadata.max_seq_len_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=forward_batch.mha_return_lse,
                        **kwargs,
                    )
@@ -985,10 +986,16 @@ class FlashAttentionBackend(AttentionBackend):
            if not layer.is_cross_attention
            else forward_batch.encoder_out_cache_loc
        )

        # if not self.use_mla:
        if k_rope is None:
            if not self.use_mla:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
                )
            else:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer, cache_loc, k, v
                )
        else:
            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
                layer,
@@ -1030,7 +1037,8 @@ class FlashAttentionBackend(AttentionBackend):
        if sinks is not None:
            kwargs["sinks"] = sinks

        # k_descale, v_descale = None, None
        k_descale, v_descale = self.k_scale, self.v_scale
        # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
        # has corresponding quantization method so that layer.k_scale is not None,
        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
@@ -1044,7 +1052,6 @@ class FlashAttentionBackend(AttentionBackend):
        k_rope = k_rope.to(self.kv_cache_dtype) if k_rope is not None else None
        if not self.use_mla:
            # Do multi-head attention
            key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
@@ -1096,65 +1103,62 @@ class FlashAttentionBackend(AttentionBackend):
                    **kwargs,
                )
            else:
                cu_seqlens_q = metadata.cu_seqlens_q
                max_seqlen_q = metadata.max_seq_len_q
                page_table = metadata.page_table
                cu_seqlens_k = metadata.cu_seqlens_k
                cache_seqlens = metadata.cache_seqlens_int32
                key_cache = key_cache.view(
                    -1, self.page_size, layer.tp_k_head_num, layer.head_dim
                )
                value_cache = value_cache.view(
                    -1, self.page_size, layer.tp_v_head_num, layer.head_dim
                )
                if layer.is_cross_attention:
                    page_table = metadata.encoder_page_table
                    cache_seqlens = metadata.encoder_lens_int32
                    cu_seqlens_k = metadata.encoder_cu_seqlens_k
                    window_size = (-1, -1)
                if max_seqlen_q > 1:
                    result = flash_attn_varlen_func(
                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                        k=k.view(-1, layer.tp_k_head_num, layer.head_dim).view(q.dtype),
                        v=v.view(-1, layer.tp_k_head_num, layer.v_head_dim).view(q.dtype),
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_k=cu_seqlens_k,
                        max_seqlen_q=max_seqlen_q,
                        max_seqlen_k=max_seqlen_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        window_size=window_size,
                        softcap=layer.logit_cap,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=use_cascade_attn,
                        num_splits=self.num_splits,
                        **kwargs,
                    )
                else:
                    result = flash_attn_with_kvcache(
                        q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
                        k_cache=key_cache,
                        v_cache=value_cache,
                        page_table=page_table,
                        cache_seqlens=cache_seqlens,
                        cu_seqlens_q=cu_seqlens_q,
                        cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
                        max_seqlen_q=max_seqlen_q,
                        softmax_scale=layer.scaling,
                        causal=True,
                        window_size=window_size,
                        softcap=layer.logit_cap,
                        k_descale=k_descale,
                        v_descale=v_descale,
                        return_softmax_lse=use_cascade_attn,
                        num_splits=self.num_splits,
                        **kwargs,
                    )
                o = result
        else:
            # Do absorbed multi-latent attention
            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).to(
......
from typing import Optional, Union

import torch

from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface,
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface,
)
def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_with_kvcache_interface(
        q=q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1]),
        k_cache=k_cache.view(q.dtype),
        v_cache=v_cache.view(q.dtype),
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )
def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    return flash_attn_varlen_func_interface(
        q=q,
        k=k.view(q.dtype),
        v=v.view(q.dtype),
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        return_attn_probs=return_softmax_lse,
        softcap=softcap,
    )
\ No newline at end of file
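# Small smoke test for the varlen shim above (illustrative only; assumes the
# flash_attn package is installed and a CUDA/ROCm device is available): two
# sequences of lengths 3 and 5 packed into one varlen batch.
if __name__ == "__main__":
    nheads, head_dim = 4, 64
    seqlens = [3, 5]
    total = sum(seqlens)
    q = torch.randn(total, nheads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    cu = torch.tensor([0, 3, 8], device="cuda", dtype=torch.int32)
    out = flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu,
        cu_seqlens_k=cu,
        max_seqlen_q=max(seqlens),
        max_seqlen_k=max(seqlens),
        causal=True,
    )
    print(out.shape)  # expected: (total, nheads, head_dim)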
@@ -16,6 +16,10 @@ from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import get_bool_env_var

from sgl_kernel.flash_mla import dcu_create_flashmla_kv_indices

if TYPE_CHECKING:
    from sglang.srt.layers.radix_attention import RadixAttention
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -79,7 +83,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens

    def init_forward_metadata(self, forward_batch: ForwardBatch):
        use_sglang_create_flashmla_kv_indices_triton = get_bool_env_var(
            "SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO"
        )
        bs = forward_batch.batch_size
        if forward_batch.forward_mode.is_decode_or_idle():
            max_seqlen_pad = triton.cdiv(
@@ -91,15 +95,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=forward_batch.seq_lens.device,
            )
            if use_sglang_create_flashmla_kv_indices_triton:
                dcu_create_flashmla_kv_indices(
                    req_to_token_ptr=self.req_to_token,
                    req_pool_indices_ptr=forward_batch.req_pool_indices,
                    page_kernel_lens_ptr=forward_batch.seq_lens,
                    kv_start_idx=None,
                    kv_indices_ptr=block_kv_indices,
                    req_to_token_ptr_stride=self.req_to_token.stride(0),
                    kv_indices_ptr_stride=max_seqlen_pad,
                )
            else:
                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    block_kv_indices,
                    self.req_to_token.stride(0),
                    max_seqlen_pad,
                )
            mla_metadata, num_splits = get_mla_metadata(
                forward_batch.seq_lens.to(torch.int32),
                self.num_q_heads,
@@ -121,15 +137,27 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
                dtype=torch.int32,
                device=seq_lens.device,
            )
            if use_sglang_create_flashmla_kv_indices_triton:
                dcu_create_flashmla_kv_indices(
                    req_to_token_ptr=self.req_to_token,
                    req_pool_indices_ptr=forward_batch.req_pool_indices,
                    page_kernel_lens_ptr=forward_batch.seq_lens,
                    kv_start_idx=None,
                    kv_indices_ptr=block_kv_indices,
                    req_to_token_ptr_stride=self.req_to_token.stride(0),
                    kv_indices_ptr_stride=max_seqlen_pad,
                )
            else:
                create_flashmla_kv_indices_triton[(bs,)](
                    self.req_to_token,
                    forward_batch.req_pool_indices,
                    forward_batch.seq_lens,
                    None,
                    block_kv_indices,
                    self.req_to_token.stride(0),
                    max_seqlen_pad,
                )
            mla_metadata, num_splits = get_mla_metadata(
                seq_lens.to(torch.int32),
                self.num_draft_tokens * self.num_q_heads,
@@ -144,7 +172,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            )
        else:
            super().init_forward_metadata(forward_batch)

    def init_cuda_graph_state(
        self,
        max_bs: int,
......
from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var, direct_register_custom_op

_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        ds_cat = None
        warnings.warn(
            "SGLANG_USE_OPT_CAT is enabled but lightop.ds_cat could not be imported; "
            "falling back to torch.cat"
        )
else:
    ds_cat = None


# TODO: registering this op on its own still has some issues
def ds_cat_wrapper(A: torch.Tensor,
                   B: torch.Tensor,
                   dim: int,
                   mode: int) -> torch.Tensor:
    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    ds_cat(A, B, C, mode)
    return C


def ds_cat_fake(A: torch.Tensor,
                B: torch.Tensor,
                dim: int,
                mode: int) -> torch.Tensor:
    # Use the standard torch.cat as the fake (meta) implementation
    return torch.cat([A, B], dim=dim)


direct_register_custom_op(
    op_name="ds_cat",
    op_func=ds_cat_wrapper,
    mutates_args=[],  # no arguments are mutated; only a value is returned
    fake_impl=ds_cat_fake,
)


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int):
    assert dim == 2, "tensor dim must be 3 and concat dim must be 2"
    mode = 0
    if dim != 0:
        return torch.ops.sglang.ds_cat(A, B, dim, mode)
    assert False, "not supported"
# def concat_decode_opt(A:torch.Tensor, B:torch.Tensor, dim:int):
# assert dim==2 , "tensor dim must be 3 and concat dim must be 2"
# output_shape = list(A.shape)
# output_shape[dim] = A.shape[dim] + B.shape[dim]
# C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
# mode=0
# if dim!=0 :
# ds_cat(A, B, C, mode)
# return C
# assert False, "not support"
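# Illustrative usage (hypothetical helper; assumes lightop is installed,
# SGLANG_USE_OPT_CAT is set, and a GPU is available): concatenate two 3-D decode
# tensors along the last dimension and compare against the torch.cat reference.
def _demo_concat_decode_opt() -> None:
    a = torch.randn(8, 16, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(8, 16, 512, device="cuda", dtype=torch.float16)
    c = concat_decode_opt(a, b, dim=2)
    assert c.shape == (8, 16, 64 + 512)
    assert torch.allclose(c.float(), torch.cat([a, b], dim=2).float())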
@@ -47,7 +47,8 @@ if _is_hip:
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )
else:
    # from sgl_kernel.flash_attn import flash_attn_with_kvcache
    from sglang.srt.layers.attention.flashattention_interface import flash_attn_with_kvcache


@dataclass(frozen=True)
......
@@ -20,7 +20,8 @@ if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

from sgl_kernel import merge_state_v2
# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
from sglang.srt.layers.attention.flashattention_interface import flash_attn_varlen_func, flash_attn_with_kvcache


class XPUAttentionBackend(AttentionBackend):
......
@@ -160,21 +160,53 @@ class RMSNorm(CustomOp):
            return output, residual_out
        return rms_norm(x, self.weight.data, self.variance_epsilon)
# def forward_hip(
# self,
# x: torch.Tensor,
# residual: Optional[torch.Tensor] = None,
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
# if not x.is_contiguous():
# # NOTE: Remove this if aiter kernel supports discontinuous input
# x = x.contiguous()
# if residual is not None:
# if _vllm_version < Version("0.9"):
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
# return x, residual
# else:
# residual_out = torch.empty_like(x)
# output = torch.empty_like(x)
# fused_add_rms_norm(
# output,
# x,
# residual_out,
# residual,
# self.weight.data,
# self.variance_epsilon,
# )
# return output, residual_out
# out = torch.empty_like(x)
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
# return out
    def forward_hip(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ):
        if not x.is_contiguous():
            x = x.contiguous()
        if residual is not None:
            try:
                fused_add_rms_norm(
                    x,
                    residual,
                    self.weight.data,
                    self.variance_epsilon,
                )
                return x, residual
            except TypeError:
                output = torch.empty_like(x)
                residual_out = torch.empty_like(x)
                fused_add_rms_norm(
                    output,
                    x,
@@ -184,10 +216,13 @@ class RMSNorm(CustomOp):
                    self.variance_epsilon,
                )
                return output, residual_out

        out = torch.empty_like(x)
        rms_norm(out, x, self.weight.data, self.variance_epsilon)
        return out
    def forward_native(
        self,
        x: torch.Tensor,
......
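# Reference for the HIP path above (a plain-PyTorch sketch of the math; the
# class's forward_native remains the authoritative fallback): fused_add_rms_norm
# adds the residual into x first, then applies RMSNorm with the layer weight.
from typing import Optional, Tuple, Union

import torch

def rms_norm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if residual is not None:
        x = x + residual
        residual = x
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    out = (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight
    return out if residual is None else (out, residual)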
@@ -45,6 +45,18 @@ _is_hip = is_hip()
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
    "SGLANG_ROCM_DISABLE_LINEARQUANT"
)
_use_fused_rms_quant = get_bool_env_var("SGLANG_USE_FUSED_RMS_QUANT")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")

if _use_fused_rms_quant:
    try:
        from lmslim.quantize.quant_ops import lm_faster_rmsquant
    except Exception as e:
        print(f"Error: failed to import the fused rms quant op: {e}")
if _use_fused_silu_mul_quant:
    try:
        from lmslim.quantize.quant_ops import lm_fuse_silu_mul_quant
    except Exception as e:
        print(f"Error: failed to import the fused silu_mul_quant op: {e}")

logger = logging.getLogger(__name__)
@@ -1360,7 +1372,7 @@ class RowParallelLinear(LinearBase):
        # It does not support additional parameters.
        param.load_row_parallel_weight(loaded_weight)

    def forward(self, input_, skip_all_reduce=False, use_fused_silu_mul_quant=False):
        if self.input_is_parallel:
            input_parallel = input_
        else:
@@ -1374,10 +1386,19 @@ class RowParallelLinear(LinearBase):
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        if use_fused_silu_mul_quant:
            xq, xs = lm_fuse_silu_mul_quant(input_parallel)
            silu_quant_args = [xq, xs]
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                output_parallel = self.quant_method.apply(
                    self,
                    input_parallel,
                    bias=bias_,
                    silu_quant_args=silu_quant_args,
                )
                sm.tag(output_parallel)
        else:
            with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
                output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
                sm.tag(output_parallel)

        if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
            output = tensor_model_parallel_all_reduce(output_parallel)
......
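For context on the fused path above, the sketch below shows in plain PyTorch what a fused SiLU·mul + per-token int8 quantization is expected to return, mirroring the `xq, xs` pair produced by `lm_fuse_silu_mul_quant`. It is not the lmslim kernel; the gate/up half-split layout and the symmetric clamp range are assumptions.

import torch
import torch.nn.functional as F

def ref_silu_mul_quant(gate_up: torch.Tensor):
    # gate_up: [tokens, 2*H]; first half gate, second half up (assumed layout).
    gate, up = gate_up.chunk(2, dim=-1)
    y = F.silu(gate.float()) * up.float()
    # Per-token absmax scale and round-to-nearest int8 quantization.
    scale = y.abs().amax(dim=-1, keepdim=True).clamp_min(1e-10) / 127.0
    yq = torch.round(y / scale).clamp(-127, 127).to(torch.int8)
    return yq, scale.to(torch.float32)

yq, ys = ref_silu_mul_quant(torch.randn(16, 2 * 256))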
@@ -2,7 +2,6 @@ import logging
 import torch
 import triton

 from sglang.srt.utils import ceil_div, is_cuda

 logger = logging.getLogger(__name__)
@@ -1015,196 +1014,133 @@ def zero_experts_compute_triton(
     return output


-@triton.jit
-def compute_problem_sizes_w4a8_kernel(
-    masked_m_ptr,
-    problem_sizes1_ptr,
-    problem_sizes2_ptr,
-    n,
-    k,
-    num_experts,
-    BLOCK_SIZE: tl.constexpr,
-):
-    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = pid < num_experts
-    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
-    ps1_idx_0 = pid * 3
-    ps1_idx_1 = ps1_idx_0 + 1
-    ps1_idx_2 = ps1_idx_0 + 2
-    ps2_idx_0 = pid * 3
-    ps2_idx_1 = ps2_idx_0 + 1
-    ps2_idx_2 = ps2_idx_0 + 2
-    ps1_mask_0 = ps1_idx_0 < num_experts * 3
-    ps1_mask_1 = ps1_idx_1 < num_experts * 3
-    ps1_mask_2 = ps1_idx_2 < num_experts * 3
-    ps2_mask_0 = ps2_idx_0 < num_experts * 3
-    ps2_mask_1 = ps2_idx_1 < num_experts * 3
-    ps2_mask_2 = ps2_idx_2 < num_experts * 3
-    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
-    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
-    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
-    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
-    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
-    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
-
-
-def compute_problem_sizes_w4a8(
-    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
-):
-    BLOCK_SIZE = 256
-    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
-    compute_problem_sizes_w4a8_kernel[grid](
-        masked_m,
-        problem_sizes1,
-        problem_sizes2,
-        n,
-        k,
-        num_experts,
-        BLOCK_SIZE=BLOCK_SIZE,
-    )
-    return problem_sizes1, problem_sizes2
-
-
-def deepep_ll_get_cutlass_w4a8_moe_mm_data(
-    masked_m,
-    problem_sizes1,
-    problem_sizes2,
-    num_experts,
-    n,
-    k,
-):
-    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
-        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
-    )
-    return (
-        problem_sizes1.to(torch.int32),
-        problem_sizes2.to(torch.int32),
-    )
+from triton.language.extra import libdevice
+from typing import Optional
+
+
+@triton.jit
+def _per_token_quant_int8_one_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    row_id = tl.program_id(0)
+
+    if tokens_per_expert_ptr is not None:
+        e = row_id // T_dim
+        t = row_id % T_dim
+        num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+        if t >= num_valid_tokens_for_e:
+            return
+
+    cols = tl.arange(0, BLOCK)
+    mask = cols < N
+    x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+    absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+    scale_x = absmax / 127
+    x_q = x * (127 / absmax)
+    x_q = libdevice.nearbyint(x_q).to(tl.int8)
+    tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
+    tl.store(scale_ptr + row_id, scale_x)
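A worked example of the absmax/127 scheme the kernel stores per row: dequantizing as `x_q * scale` bounds the per-element error by `scale / 2` (torch.round matches libdevice.nearbyint's round-to-nearest-even here).

import torch

row = torch.tensor([0.5, -2.0, 1.0])
absmax = row.abs().max().clamp_min(1e-10)               # 2.0
scale = absmax / 127                                    # ~0.01575
q = torch.round(row * (127 / absmax)).to(torch.int8)    # [32, -127, 64]
recon = q.float() * scale
assert torch.all((recon - row).abs() <= scale / 2 + 1e-6)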
-@triton.jit
-def _silu_and_mul_post_per_tensor_quant_kernel(
-    input_ptr,
-    stride_input_expert,
-    stride_input_token,
-    stride_input_dim,
-    output_ptr,
-    stride_output_expert,
-    stride_output_token,
-    stride_output_dim,
-    scale_ptr,
-    masked_m_ptr,
-    inner_dim,
-    fp8_max,
-    fp8_min,
-    BLOCK_N: tl.constexpr,
-    NUM_STAGE: tl.constexpr,
-):
-    """
-    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
-
-    Shape:
-        input:  [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
-        output: [E, T_padded, D], dtype=float8_e4m3fn
-    """
-    expert_id = tl.program_id(2)
-    block_id_token = tl.program_id(1)
-    block_id_dim = tl.program_id(0)
-    num_token_blocks = tl.num_programs(1)
-    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
-
-    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
-
-    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
-    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
-    stride_input_token = tl.cast(stride_input_token, tl.int32)
-    stride_output_token = tl.cast(stride_output_token, tl.int32)
-
-    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
-    mask_d = offset_d < inner_dim
-
-    # base pointers for current expert and dim block
-    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
-    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
-
-    for token_idx in tl.range(
-        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
-    ):
-        gate_ptr = input_base_offs + token_idx * stride_input_token
-        up_ptr = gate_ptr + inner_dim
-        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
-        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
-
-        # SiLU: x * sigmoid(x)
-        gate = gate / (1 + tl.exp(-gate))
-        gate = gate.to(input_ptr.dtype.element_ty)
-        gate_up = up * gate
-
-        scaled = gate_up * scale
-        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
-        out_ptr = output_base_offs + token_idx * stride_output_token
-        tl.store(out_ptr, output_q, mask=mask_d)
-
-
-def silu_and_mul_masked_post_per_tensor_quant_fwd(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    masked_m: torch.Tensor,
-    scale: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
-
-    Args:
-        input: [expert_num, token_num_padded, 2 * inner_dim]
-        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
-        masked_m: [expert_num], actual token count for each expert
-        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
-
-    Returns:
-        output tensor
-    """
-    assert input.is_contiguous()
-    assert output.is_contiguous()
-    assert output.dtype == torch.float8_e4m3fn
-    assert input.ndim == 3
-    assert input.shape[0] == masked_m.shape[0]
-    assert input.shape[-1] % 2 == 0
-    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
-
-    expert_num = input.shape[0]
-    # 3584
-    inner_dim = input.shape[-1] // 2
-
-    BLOCK_N = 256
-    BLOCK_M = 64 if expert_num < 4 else 32
-    NUM_STAGES = 3
-    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
-    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
-    finfo = torch.finfo(torch.float8_e4m3fn)
-    fp8_max = finfo.max
-    fp8_min = -fp8_max
-    _silu_and_mul_post_per_tensor_quant_kernel[grid](
-        input,
-        *input.stride(),
-        output,
-        *output.stride(),
-        scale,
-        masked_m,
-        inner_dim,
-        fp8_max,
-        fp8_min,
-        BLOCK_N=BLOCK_N,
-        NUM_STAGE=NUM_STAGES,
-    )
-    return output
+@triton.jit
+def _per_token_quant_int8_kernel_opt(
+    x_ptr,
+    xq_ptr,
+    scale_ptr,
+    stride_x,
+    stride_xq,
+    N,
+    E_dim,
+    T_dim,
+    tokens_per_expert_ptr,
+    BLOCK: tl.constexpr,
+):
+    token_idx_start = tl.program_id(0)
+    grid_size = tl.num_programs(0)
+    num_total_tokens = E_dim * T_dim
+
+    for token_idx in range(token_idx_start, num_total_tokens, grid_size):
+        is_valid_token = True
+        if tokens_per_expert_ptr is not None:
+            e = token_idx // T_dim
+            t = token_idx % T_dim
+            num_valid_tokens_for_e = tl.load(tokens_per_expert_ptr + e)
+            if t >= num_valid_tokens_for_e:
+                is_valid_token = False
+
+        if is_valid_token:
+            cols = tl.arange(0, BLOCK)
+            mask = cols < N
+
+            x = tl.load(x_ptr + token_idx * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
+            absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
+            scale_x = absmax / 127
+            x_q = x * (127 / absmax)
+            x_q = libdevice.nearbyint(x_q).to(tl.int8)
+
+            tl.store(xq_ptr + token_idx * stride_xq + cols, x_q, mask=mask)
+            tl.store(scale_ptr + token_idx, scale_x)
+
+
+def per_token_quant_int8_triton_opt(x: torch.Tensor,
+                                    tokens_per_expert: Optional[torch.Tensor] = None):
+    if x.dim() != 3:
+        raise ValueError(f"Input must be 3D [E, T, H], but got {x.shape}")
+    E, T, H = x.shape
+    N = H
+
+    x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+    scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
+    BLOCK = triton.next_power_of_2(N)
+    num_warps = min(max(BLOCK // 256, 1), 8)
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        num_warps = 1
+
+    num_tokens = E * T
+    grid_opt = num_tokens
+    if (E == 8 and T >= 1024) or (E == 16 and T >= 512):
+        grid_opt = max(1, num_tokens // (T // 256))
+        _per_token_quant_int8_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            E_dim=E,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    else:
+        _per_token_quant_int8_one_kernel_opt[(grid_opt,)](
+            x,
+            x_q,
+            scales,
+            stride_x=x.stride(-2),
+            stride_xq=x_q.stride(-2),
+            N=N,
+            T_dim=T,
+            tokens_per_expert_ptr=tokens_per_expert,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=1,
+        )
+    return x_q, scales
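A hypothetical usage sketch of the new entry point follows (requires a CUDA device with Triton installed; the int32 dtype for `tokens_per_expert` and the shapes are illustrative). Because the output buffers come from torch.empty and both kernels skip rows past `tokens_per_expert[e]`, those rows appear to be left uninitialized.

import torch

E, T, H = 8, 1024, 4096
x = torch.randn(E, T, H, device="cuda", dtype=torch.bfloat16)
tokens_per_expert = torch.randint(0, T, (E,), device="cuda", dtype=torch.int32)

x_q, scales = per_token_quant_int8_triton_opt(x, tokens_per_expert)
assert x_q.shape == (E, T, H) and scales.shape == (E, T, 1)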
@@ -65,7 +65,7 @@ def inplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -84,6 +84,8 @@ def inplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> None:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     fused_experts_impl(
         hidden_states,
         w1,
@@ -123,7 +125,7 @@ def inplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -161,7 +163,7 @@ def outplace_fused_experts(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -181,6 +183,8 @@ def outplace_fused_experts(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ) -> torch.Tensor:
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     return fused_experts_impl(
         hidden_states,
         w1,
@@ -220,7 +224,7 @@ def outplace_fused_experts_fake(
     topk_ids: torch.Tensor,
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -273,9 +277,12 @@ def fused_experts(
     block_shape: Optional[List[int]] = None,
 ):
     topk_weights, topk_ids, _ = topk_output
-    filter_expert = (
-        moe_runner_config.num_experts is None
-        or moe_runner_config.num_experts != moe_runner_config.num_local_experts
+    act_id = (
+        0 if (
+            moe_runner_config.activation == 0
+            or (isinstance(moe_runner_config.activation, str)
+                and moe_runner_config.activation.lower() == "silu")
+        ) else 1
     )
     if moe_runner_config.inplace:
         assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
@@ -287,7 +294,7 @@ def fused_experts(
         topk_ids,
         b1,
         b2,
-        moe_runner_config.activation,
+        act_id,
         moe_runner_config.apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
@@ -316,7 +323,7 @@ def fused_experts(
         topk_ids,
         b1,
         b2,
-        moe_runner_config.activation,
+        act_id,
         moe_runner_config.apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
@@ -366,7 +373,7 @@ def fused_experts_impl(
     b1: Optional[torch.Tensor] = None,
     b2: Optional[torch.Tensor] = None,
     inplace: bool = False,
-    activation: str = "silu",
+    activation: int = 0,  # 0: silu, 1: gelu
     apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
@@ -386,6 +393,8 @@ def fused_experts_impl(
     gemm1_limit: Optional[float] = None,
     filter_expert: bool = True,
 ):
+    if isinstance(activation, int):
+        activation = "silu" if activation == 0 else "gelu"
     padded_size = padding_size
     if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0
...
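The integer activation id introduced above (0 = silu, 1 = gelu) is normalized back to a string at the top of each wrapper, presumably so the value can travel through op schemas that accept plain ints. A minimal sketch of that convention; the helper and mapping names here are illustrative, not part of the codebase.

ACTIVATION_IDS = {"silu": 0, "gelu": 1}

def normalize_activation(activation):
    # Accept either the integer id or the string name and return the string,
    # mirroring the `isinstance(activation, int)` check in the wrappers above.
    if isinstance(activation, int):
        return "silu" if activation == 0 else "gelu"
    return activation

assert normalize_activation(0) == "silu"
assert normalize_activation(ACTIVATION_IDS["gelu"]) == "gelu"
assert normalize_activation("silu") == "silu"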