Unverified Commit 71046fcd authored by Meng, Hengyu, committed by GitHub

[XPU][CPU] Enable the native path of DeepSeek (#4086)


Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
parent c76040e3
......@@ -52,6 +52,15 @@ cd ..
pip install -e "python[all_hip]"
```
Note: For Intel GPU, do the following instead:
```
git clone https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python[all_xpu]"
```
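To confirm the XPU backend is visible to PyTorch after installation, a quick check (assuming an XPU-enabled torch build) is:
```
import torch

# Requires an XPU-enabled PyTorch build.
print(hasattr(torch, "xpu") and torch.xpu.is_available())
```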
## Method 3: Using docker
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
......
......@@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"]
# For Intel Gaudi(device : hpu) follow the installation guide
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
......
......@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
This requirement is due to the custom CUDA kernels used by the
quantization method.
This requirement is due to the custom kernels used by the
quantization method or to stock PyTorch capabilities.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def get_availability(cls) -> bool:
"""Whether the quantization config is available on current device.
This requirement is due to the custom kernels used by the
quantization method or the stock pytorch capability.
"""
raise NotImplementedError
......
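For context, a hedged sketch of how a hypothetical vendor config might implement the new `get_availability()` hook alongside `get_min_capability()` (the class name, backend check, and threshold below are illustrative assumptions, not part of this change):
```
import sys

import torch

from sglang.srt.layers.quantization.base_config import QuantizationConfig


class ExampleVendorQuantConfig(QuantizationConfig):
    # Hypothetical vendor config; other abstract methods are omitted.

    @classmethod
    def get_min_capability(cls) -> int:
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return 75
        # No meaningful CUDA capability number on other devices.
        return sys.maxsize

    @classmethod
    def get_availability(cls) -> bool:
        # Available whenever the vendor's torch backend (here: XPU) is present.
        return hasattr(torch, "xpu") and torch.xpu.is_available()
```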
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
......@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import get_device_capability, set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"]
......@@ -71,7 +72,20 @@ class BlockInt8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
......@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear,
apply_w8a8_block_fp8_linear,
......@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
is_cuda,
is_hip,
permute_weight,
print_warning_once,
......@@ -55,6 +62,7 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
......@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return 0
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......@@ -850,6 +875,52 @@ class Fp8MoEMethod:
)
torch.cuda.empty_cache()
def torch_w8a8_block_fp8_moe(
self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
):
from sglang.srt.layers.activation import SiluAndMul
"""This function performs fused moe with block-wise quantization using native torch."""
B, D = a.shape
topk = topk_ids.shape[-1]
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
# NOTE(HandH1998): Since "index_cuda" is not implemented for 'Float8_e4m3fn', we need to cast `float8` to `float32`.
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_fp8_matmul(
a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype,
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k
)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_fp8_matmul(
act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype,
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def apply(
self,
layer: torch.nn.Module,
......@@ -923,7 +994,7 @@ class Fp8MoEMethod:
layer.w13_weight_scale1,
layer.w2_weight_scale1,
)
else:
elif _is_cuda:
# Expert fusion with FP8 quantization
return fused_experts(
x,
......@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
no_combine=no_combine,
)
# for CPU and other accelerators, fallback to native path
return self.torch_w8a8_block_fp8_moe(
a=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_s=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_s=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
topk_weight=topk_weights,
topk_ids=topk_ids,
block_shape=self.quant_config.weight_block_size,
)
class Fp8KVCacheMethod(BaseKVCacheMethod):
"""
......
......@@ -28,10 +28,13 @@ from sglang.srt.utils import (
get_device_name,
is_cuda,
is_hip,
is_triton_available,
supports_custom_op,
)
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_triton = is_triton_available()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda()
......@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s)
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
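A minimal CPU round trip through `native_per_token_group_quant_fp8` (the shapes and group size below are illustrative assumptions, not part of this change):
```
import torch

from sglang.srt.layers.quantization.fp8_kernel import native_per_token_group_quant_fp8

# Quantize per 128-element group, then dequantize with the returned scales.
x = torch.randn(4, 256, dtype=torch.float32)
x_q, x_s = native_per_token_group_quant_fp8(x, group_size=128)

# x_s holds one scale per group: (4, 256 // 128) == (4, 2).
x_dq = x_q.to(torch.float32) * x_s.repeat_interleave(128, dim=-1)
print((x - x_dq).abs().max())  # small, bounded by fp8 e4m3 precision
```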
def per_token_group_quant_fp8(
x: torch.Tensor,
group_size: int,
......@@ -232,19 +263,22 @@ def per_token_group_quant_fp8(
num_stages=num_stages,
)
else:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
if _is_triton:
_per_token_group_quant_fp8[(M,)](
x,
x_q,
x_s,
group_size,
N,
eps,
fp8_min=fp8_min,
fp8_max=fp8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
)
else:
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)
return x_q, x_s
......@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
return None
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
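A small CPU sketch exercising the native block-quantized matmul against a float32 reference (the shapes, unit weight scales, and block sizes are illustrative assumptions):
```
import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_fp8_matmul,
)

block_n, block_k = 128, 128
M, N, K = 4, 256, 512

A = torch.randn(M, K, dtype=torch.float32)
B = torch.randn(N, K, dtype=torch.float32)

# Per-token-group quantized activations; weights cast to fp8 with unit block scales.
A_q, A_s = native_per_token_group_quant_fp8(A, block_k)
B_q = B.to(torch.float8_e4m3fn)
B_s = torch.ones(N // block_n, K // block_k, dtype=torch.float32)

C = native_w8a8_block_fp8_matmul(
    A_q, B_q, A_s, B_s, [block_n, block_k], output_dtype=torch.float32
)
print((C - A @ B.t()).abs().max())  # dominated by fp8 quantization error
```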
def w8a8_block_fp8_matmul(
A: torch.Tensor,
B: torch.Tensor,
......@@ -715,86 +804,90 @@ def w8a8_block_fp8_matmul(
Returns:
torch.Tensor: The result of matmul.
"""
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
if _is_triton: # pragma: no cover
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert A.shape[-1] == B.shape[-1]
assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
M = A.numel() // A.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
N, K = B.shape
assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
if configs:
# If an optimal configuration map has been found, look up the
# optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Default config
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_size[0],
"BLOCK_SIZE_K": block_size[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 3,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
# Empirical testing shows the sweet spot lies when it's less than the # of
# compute units available on the device.
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
# deepgemm only supports bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
# deepgemm only supports bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
if supports_custom_op():
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
else:
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
return C
import logging
import sys
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
......@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability
logger = logging.getLogger(__name__)
......@@ -90,7 +92,20 @@ class GPTQConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 60
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 60
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......@@ -209,7 +224,20 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......@@ -371,7 +399,20 @@ class MarlinConfig(QuantizationConfig):
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 80
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging
import sys
from typing import Any, Dict, List, Optional
import torch
......@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.utils import get_device_capability
# Initialize logger for the module
logger = logging.getLogger(__name__)
......@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs).
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 89
# Vendors can update
return False
@classmethod
def get_config_filenames(cls) -> List[str]:
......
......@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from sglang.srt.utils import is_hip
from sglang.srt.utils import get_device_capability, is_hip
_is_hip = is_hip()
......@@ -35,7 +35,20 @@ class W8A8Fp8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 89
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 89
# Vendors can update
return False
@classmethod
def get_name(self) -> str:
......
import sys
from typing import Any, Callable, Dict, List, Optional
import torch
......@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import get_device_capability
class W8A8Int8Config(QuantizationConfig):
......@@ -36,7 +38,20 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 75
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 75
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 75
# Vendors can update
return False
@classmethod
def get_name(self) -> str:
......
......@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
) -> None:
super().__init__()
self.head_size = head_size
......@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
self.device = device
cache = self._compute_cos_sin_cache()
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
......@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factors: Union[List[float], float],
dtype: torch.dtype,
device: str,
) -> None:
if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
# Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int]
......@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
) -> None:
self.scaling_factor = scaling_factor
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
def _compute_cos_sin_cache(self) -> torch.Tensor:
......@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
......@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
# Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
......@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
scaling_factor: float,
dtype: torch.dtype,
device: str,
*,
extrapolation_factor: float = 1,
attn_factor: float = 1,
......@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1,
mscale: float = 1,
mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
) -> None:
self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor
......@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor
)
self.device = device
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
......@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
scaling_factor: float,
low_freq_factor: float,
high_freq_factor: float,
......@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
......@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int,
is_neox_style: bool,
dtype: torch.dtype,
device: str,
mrope_section: Optional[List[int]] = None,
) -> None:
super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
)
self.mrope_section = mrope_section
......@@ -1003,9 +1045,14 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
) -> RotaryEmbedding:
if dtype is None:
dtype = torch.get_default_dtype()
if device is None:
from sglang.srt.managers.schedule_batch import global_server_args_dict
device = global_server_args_dict["device"]
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
......@@ -1030,7 +1077,7 @@ def get_rope(
if rope_scaling is None:
rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype
head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
)
else:
if "rope_type" in rope_scaling:
......@@ -1052,6 +1099,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
scaling_factor,
low_freq_factor,
high_freq_factor,
......@@ -1066,6 +1114,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
mrope_section=rope_scaling["mrope_section"],
)
else:
......@@ -1076,6 +1125,7 @@ def get_rope(
base,
is_neox_style,
dtype,
device,
)
elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"]
......@@ -1087,6 +1137,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"]
......@@ -1098,6 +1149,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
)
elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"]
......@@ -1116,6 +1168,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "deepseek_yarn":
......@@ -1143,6 +1196,7 @@ def get_rope(
is_neox_style,
scaling_factor,
dtype,
device,
**extra_kwargs,
)
elif scaling_type == "longrope":
......@@ -1253,21 +1307,8 @@ def get_rope_wrapper(
rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
):
if device != "cpu":
return get_rope(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
)
return get_rope_cpu(
return get_rope(
head_size,
rotary_dim,
max_position,
......@@ -1276,5 +1317,4 @@ def get_rope_wrapper(
rope_scaling,
dtype,
partial_rotary_factor,
device,
)
......@@ -35,6 +35,8 @@ from sglang.srt.distributed import (
set_custom_all_reduce,
)
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
get_attention_tp_size,
......
......@@ -108,15 +108,25 @@ def _get_quantization_config(
quant_config = get_quant_config(model_config, load_config)
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
if not hasattr(quant_config, "get_availability"):
# TODO: update vLLM quant configs to implement get_availability; fall back to the capability check until then
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
else:
if not quant_config.get_availability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
f"Current capability: {major, minor}."
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
......
......@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant,
)
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False,
device=global_server_args_dict["device"],
)
if rope_scaling:
......@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
......@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
if self.w_kc.dtype == torch.float8_e4m3fnuz:
if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_kc.dtype == torch.float8_e4m3fn:
elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn
)
q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
)
else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
......@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fnuz:
if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale,
)
elif self.w_vc.dtype == torch.float8_e4m3fn:
elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn
)
......@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale,
torch.bfloat16,
)
else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
......
......@@ -126,6 +126,14 @@ def is_cuda_available():
return is_cuda()
def is_triton_available():
if is_cuda() or is_xpu() or is_hip():
return get_bool_env_var("TRITON_AVAILABLE", default="true")
else:
# update once CPU/HPU support Triton
return False
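Since the helper reads the `TRITON_AVAILABLE` environment variable through `get_bool_env_var`, the native torch path can also be forced on CUDA/XPU/HIP devices; a hedged example (set before the kernels are dispatched):
```
import os

# Force the native torch fallback even where Triton would normally be used.
os.environ["TRITON_AVAILABLE"] = "false"
```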
def enable_show_time_cost():
global show_time_cost
show_time_cost = True
......@@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
major, minor = None, None
if hasattr(torch, "cuda") and torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device_id)
assert 0 <= minor < 10
if hasattr(torch, "xpu") and torch.xpu.is_available():
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
......
......@@ -7,6 +7,8 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
static_quant_fp8,
w8a8_block_fp8_matmul,
......@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
_is_cuda = torch.cuda.is_available() and torch.version.cuda
# For test
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
class TestPerTokenGroupQuantFP8(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
......@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase):
self._static_quant_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
class TestW8A8BlockFP8Matmul(unittest.TestCase):
if not _is_cuda:
......