Unverified commit 71046fcd, authored by Meng, Hengyu and committed by GitHub

[XPU][CPU] Enable the native path of DeepSeek (#4086)


Co-authored-by: Zhang, Liangang <liangang.zhang@intel.com>
parent c76040e3
@@ -52,6 +52,15 @@ cd ..
pip install -e "python[all_hip]"
```
Note: For Intel GPU, do the following instead:
```
git clone https://github.com/sgl-project/sglang.git
cd sglang
pip install --upgrade pip
pip install -e "python[all_xpu]"
```
## Method 3: Using docker
The docker images are available on Docker Hub as [lmsysorg/sglang](https://hub.docker.com/r/lmsysorg/sglang/tags), built from [Dockerfile](https://github.com/sgl-project/sglang/tree/main/docker).
Replace `<secret>` below with your huggingface hub [token](https://huggingface.co/docs/hub/en/security-tokens).
......
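A quick post-install sanity check can confirm that the XPU or CPU-only build imports and sees the expected backend. This is an illustrative snippet, not part of the diff; it assumes `sglang` exposes `__version__`:

```python
# Hypothetical post-install smoke test for the XPU/CPU build.
import torch
import sglang

print("sglang:", sglang.__version__)
print("cuda available:", hasattr(torch, "cuda") and torch.cuda.is_available())
print("xpu available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```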
@@ -57,7 +57,7 @@ srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1
# xpu is not enabled in public vllm and torch whl,
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
-srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<=0.1.11"]
+srt_xpu = ["sglang[runtime_common]", "vllm>=0.6.4.post1,<=0.7.2", "outlines>=0.0.44,<=0.1.11"]
# For Intel Gaudi (device: hpu) follow the installation guide
# https://docs.vllm.ai/en/latest/getting_started/gaudi-installation.html
......
@@ -54,8 +54,18 @@ class QuantizationConfig(ABC):
"""Minimum GPU capability to support the quantization method.
E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
-This requirement is due to the custom CUDA kernels used by the
-quantization method.
+This requirement is due to the custom kernels used by the
+quantization method or the stock PyTorch capability.
"""
raise NotImplementedError
@classmethod
@abstractmethod
def get_availability(cls) -> bool:
"""Whether the quantization config is available on the current device.
This requirement is due to the custom kernels used by the
quantization method or the stock PyTorch capability.
"""
raise NotImplementedError
......
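The new `get_availability()` hook complements `get_min_capability()`: the former answers whether the quantization method can run on the current device at all, while the latter still encodes the CUDA compute-capability floor. A minimal sketch of how a backend-specific config might implement both (the class name `VendorInt8Config` is hypothetical, and the remaining abstract methods of `QuantizationConfig` are omitted):

```python
import sys

import torch

from sglang.srt.layers.quantization.base_config import QuantizationConfig


class VendorInt8Config(QuantizationConfig):
    """Hypothetical example config; remaining abstract methods omitted."""

    @classmethod
    def get_min_capability(cls) -> int:
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            return 75  # CUDA path is still gated by compute capability
        # Vendors can update
        return sys.maxsize

    @classmethod
    def get_availability(cls) -> bool:
        if hasattr(torch, "cuda") and torch.cuda.is_available():
            major, minor = torch.cuda.get_device_capability()
            return major * 10 + minor >= 75
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return True  # stock-PyTorch kernels are sufficient on XPU
        # Vendors can update
        return False
```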
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging import logging
import sys
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -19,7 +20,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
from sglang.srt.utils import set_weight_attrs from sglang.srt.utils import get_device_capability, set_weight_attrs
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
...@@ -71,8 +72,21 @@ class BlockInt8Config(QuantizationConfig): ...@@ -71,8 +72,21 @@ class BlockInt8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80 return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
......
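For reference, the `major * 10 + minor` encoding used in these checks is the same value that `get_min_capability()` is compared against elsewhere; a short illustration (values are examples only):

```python
# Compute-capability encoding used throughout the quantization configs.
major, minor = 8, 0              # e.g. an A100 reports (8, 0)
capability = major * 10 + minor  # -> 80, compared against get_min_capability()
print(capability >= 80)          # True: the block-wise INT8 path is allowed
```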
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
import logging import logging
import sys
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -36,7 +37,11 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig, QuantizationConfig,
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8,
)
from sglang.srt.layers.quantization.fp8_utils import ( from sglang.srt.layers.quantization.fp8_utils import (
apply_fp8_linear, apply_fp8_linear,
apply_w8a8_block_fp8_linear, apply_w8a8_block_fp8_linear,
...@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import ( ...@@ -46,6 +51,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
) )
from sglang.srt.utils import ( from sglang.srt.utils import (
get_bool_env_var, get_bool_env_var,
get_device_capability,
is_cuda,
is_hip, is_hip,
permute_weight, permute_weight,
print_warning_once, print_warning_once,
...@@ -55,6 +62,7 @@ from sglang.srt.utils import ( ...@@ -55,6 +62,7 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"] ACTIVATION_SCHEMES = ["static", "dynamic"]
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda()
if _is_hip: if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe from aiter.fused_moe_bf16_asm import asm_moe
...@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig): ...@@ -108,7 +116,24 @@ class Fp8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80 return 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return 0
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
if hasattr(torch, "xpu") and torch.xpu.is_available():
return True
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
...@@ -850,6 +875,52 @@ class Fp8MoEMethod: ...@@ -850,6 +875,52 @@ class Fp8MoEMethod:
) )
torch.cuda.empty_cache() torch.cuda.empty_cache()
def torch_w8a8_block_fp8_moe(
self, a, w1, w2, w1_s, w2_s, topk_weight, topk_ids, block_shape
):
"""This function performs fused MoE with block-wise quantization using native torch."""
from sglang.srt.layers.activation import SiluAndMul
B, D = a.shape
topk = topk_ids.shape[-1]
a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device)
topk_weight = topk_weight.view(-1)
topk_ids = topk_ids.view(-1)
_, block_k = block_shape[0], block_shape[1]
a_q, a_s = native_per_token_group_quant_fp8(a, block_k)
# NOTE(HandH1998): Since "index_cuda" is not implemented for 'Float8_e4m3fn', we need to cast `float8` to `float32`.
a_q = a_q.to(torch.float32)
for i in range(w1.shape[0]):
mask = topk_ids == i
if mask.sum():
inter_out = native_w8a8_block_fp8_matmul(
a_q[mask],
w1[i],
a_s[mask],
w1_s[i],
block_shape,
output_dtype=a.dtype,
)
act_out = SiluAndMul().forward_native(inter_out)
act_out_q, act_out_s = native_per_token_group_quant_fp8(
act_out, block_k
)
act_out = act_out.to(torch.float32)
out[mask] = native_w8a8_block_fp8_matmul(
act_out_q,
w2[i],
act_out_s,
w2_s[i],
block_shape,
output_dtype=a.dtype,
)
return (
out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype)
).sum(dim=1)
def apply( def apply(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -923,7 +994,7 @@ class Fp8MoEMethod: ...@@ -923,7 +994,7 @@ class Fp8MoEMethod:
layer.w13_weight_scale1, layer.w13_weight_scale1,
layer.w2_weight_scale1, layer.w2_weight_scale1,
) )
else: elif _is_cuda:
# Expert fusion with FP8 quantization # Expert fusion with FP8 quantization
return fused_experts( return fused_experts(
x, x,
...@@ -950,6 +1021,24 @@ class Fp8MoEMethod: ...@@ -950,6 +1021,24 @@ class Fp8MoEMethod:
no_combine=no_combine, no_combine=no_combine,
) )
# For CPU and other accelerators, fall back to the native path
return self.torch_w8a8_block_fp8_moe(
a=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
w1_s=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_s=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
topk_weight=topk_weights,
topk_ids=topk_ids,
block_shape=self.quant_config.weight_block_size,
)
class Fp8KVCacheMethod(BaseKVCacheMethod): class Fp8KVCacheMethod(BaseKVCacheMethod):
""" """
......
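The `torch_w8a8_block_fp8_moe` fallback above iterates over experts and gathers each expert's tokens with a boolean mask before applying the gate/up projection, SiLU-and-mul, and down projection. The same routing pattern in plain fp32 torch, with the quantization stripped out (a sketch; tensor names and sizes are illustrative only):

```python
import torch


def naive_moe(a, w1, w2, topk_weight, topk_ids):
    """Unquantized reference MoE: a (B, D), w1 (E, 2I, D), w2 (E, D, I)."""
    B, D = a.shape
    topk = topk_ids.shape[-1]
    a = a.view(B, 1, D).repeat(1, topk, 1).reshape(-1, D)  # one row per (token, expert) pair
    out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype)
    topk_weight, topk_ids = topk_weight.view(-1), topk_ids.view(-1)
    for e in range(w1.shape[0]):  # loop over experts, as in the block-FP8 fallback above
        mask = topk_ids == e
        if mask.sum():
            gate_up = a[mask] @ w1[e].t()              # (n_e, 2I)
            gate, up = gate_up.chunk(2, dim=-1)
            act = torch.nn.functional.silu(gate) * up  # SiluAndMul
            out[mask] = act @ w2[e].t()                # (n_e, D)
    return (out.view(B, topk, -1) * topk_weight.view(B, topk, 1).to(out.dtype)).sum(dim=1)


# Tiny usage example: 4 tokens, 3 experts, top-2 routing, D=8, I=16.
a = torch.randn(4, 8)
w1, w2 = torch.randn(3, 32, 8), torch.randn(3, 8, 16)
topk_weight = torch.softmax(torch.randn(4, 2), dim=-1)
topk_ids = torch.randint(0, 3, (4, 2))
print(naive_moe(a, w1, w2, topk_weight, topk_ids).shape)  # torch.Size([4, 8])
```

The fused Triton/CUDA kernels compute the same result; the loop form is only meant to make the per-expert data flow explicit.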
...@@ -28,10 +28,13 @@ from sglang.srt.utils import ( ...@@ -28,10 +28,13 @@ from sglang.srt.utils import (
get_device_name, get_device_name,
is_cuda, is_cuda,
is_hip, is_hip,
is_triton_available,
supports_custom_op, supports_custom_op,
) )
_is_hip = is_hip() _is_hip = is_hip()
_is_cuda = is_cuda()
_is_triton = is_triton_available()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = is_cuda() _is_cuda = is_cuda()
...@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor( ...@@ -162,6 +165,34 @@ def _per_token_group_quant_fp8_colmajor(
tl.store(y_s_ptr, y_s) tl.store(y_s_ptr, y_s)
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` must be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
def per_token_group_quant_fp8( def per_token_group_quant_fp8(
x: torch.Tensor, x: torch.Tensor,
group_size: int, group_size: int,
...@@ -232,6 +263,7 @@ def per_token_group_quant_fp8( ...@@ -232,6 +263,7 @@ def per_token_group_quant_fp8(
num_stages=num_stages, num_stages=num_stages,
) )
else: else:
if _is_triton:
_per_token_group_quant_fp8[(M,)]( _per_token_group_quant_fp8[(M,)](
x, x,
x_q, x_q,
...@@ -245,6 +277,8 @@ def per_token_group_quant_fp8( ...@@ -245,6 +277,8 @@ def per_token_group_quant_fp8(
num_warps=num_warps, num_warps=num_warps,
num_stages=num_stages, num_stages=num_stages,
) )
else:
x_q, x_s = native_per_token_group_quant_fp8(x, group_size)
return x_q, x_s return x_q, x_s
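A small round-trip check makes the scaling convention of `native_per_token_group_quant_fp8` concrete: each group of `group_size` values shares one scale `amax / fp8_max`, so dequantizing with that scale recovers the input to within FP8 E4M3 rounding error. A sketch, assuming the `fp8_kernel` module (and its Triton dependency) imports in your environment and your PyTorch build supports `torch.float8_e4m3fn` on the target device:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import native_per_token_group_quant_fp8

x = torch.randn(4, 256, dtype=torch.float32)
x_q, x_s = native_per_token_group_quant_fp8(x, group_size=128)
print(x_q.dtype, x_q.shape, x_s.shape)  # torch.float8_e4m3fn, (4, 256), (4, 2)

# Dequantize: broadcast each group's scale back over its 128 values.
x_deq = x_q.to(torch.float32).reshape(4, 2, 128) * x_s.reshape(4, 2, 1)
print((x_deq.reshape(4, 256) - x).abs().max())  # small; bounded by E4M3 rounding error
```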
...@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs( ...@@ -691,6 +725,61 @@ def get_w8a8_block_fp8_configs(
return None return None
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def w8a8_block_fp8_matmul( def w8a8_block_fp8_matmul(
A: torch.Tensor, A: torch.Tensor,
B: torch.Tensor, B: torch.Tensor,
...@@ -715,6 +804,7 @@ def w8a8_block_fp8_matmul( ...@@ -715,6 +804,7 @@ def w8a8_block_fp8_matmul(
Returns: Returns:
torch.Tensor: The result of matmul. torch.Tensor: The result of matmul.
""" """
if _is_triton: # pragma: no cover
assert len(block_size) == 2 assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1] block_n, block_k = block_size[0], block_size[1]
...@@ -750,7 +840,8 @@ def w8a8_block_fp8_matmul( ...@@ -750,7 +840,8 @@ def w8a8_block_fp8_matmul(
def grid(META): def grid(META):
return ( return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), triton.cdiv(M, META["BLOCK_SIZE_M"])
* triton.cdiv(N, META["BLOCK_SIZE_N"]),
) )
# Use manually unrolledx4 kernel on AMD GPU when the grid size is small. # Use manually unrolledx4 kernel on AMD GPU when the grid size is small.
...@@ -796,5 +887,7 @@ def w8a8_block_fp8_matmul( ...@@ -796,5 +887,7 @@ def w8a8_block_fp8_matmul(
Bs.stride(0), Bs.stride(0),
**config, **config,
) )
else:
C = native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype)
return C return C
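To see `native_w8a8_block_fp8_matmul` end to end, one can quantize the activations with the per-token-group helper above and quantize the weight block-wise by hand (one scale per `block_n x block_k` tile, matching the `Bs[j][i]` indexing in the loop). A sketch under the same environment assumptions as above; shapes and block sizes are illustrative:

```python
import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    native_per_token_group_quant_fp8,
    native_w8a8_block_fp8_matmul,
)

torch.manual_seed(0)
M, K, N = 4, 256, 64
block_n, block_k = 64, 128  # block_size = [block_n, block_k]
A = torch.randn(M, K)
B = torch.randn(N, K)

# Activations: per-token-group quantization with group_size == block_k.
A_q, A_s = native_per_token_group_quant_fp8(A, block_k)

# Weight: one scale per (block_n, block_k) tile, so Bs has shape (N//block_n, K//block_k).
fp8_max = torch.finfo(torch.float8_e4m3fn).max
B_tiles = B.reshape(N // block_n, block_n, K // block_k, block_k)
Bs = B_tiles.abs().amax(dim=(1, 3)) / fp8_max
B_q = (
    (B_tiles / Bs[:, None, :, None])
    .clamp(-fp8_max, fp8_max)
    .reshape(N, K)
    .contiguous()
    .to(torch.float8_e4m3fn)
)

out = native_w8a8_block_fp8_matmul(A_q, B_q, A_s, Bs, [block_n, block_k], output_dtype=torch.float32)
ref = A @ B.t()
print(((out - ref).abs().mean() / ref.abs().mean()).item())  # typically a few percent
```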
import logging import logging
import sys
from fractions import Fraction from fractions import Fraction
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
...@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types ...@@ -8,6 +9,7 @@ from vllm.scalar_type import scalar_types
from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -90,8 +92,21 @@ class GPTQConfig(QuantizationConfig): ...@@ -90,8 +92,21 @@ class GPTQConfig(QuantizationConfig):
@classmethod @classmethod
# Need to figure it out # Need to figure it out
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 60 return 60
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 60
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
...@@ -209,8 +224,21 @@ class GPTQMarlinConfig(QuantizationConfig): ...@@ -209,8 +224,21 @@ class GPTQMarlinConfig(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80 return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
...@@ -371,8 +399,21 @@ class MarlinConfig(QuantizationConfig): ...@@ -371,8 +399,21 @@ class MarlinConfig(QuantizationConfig):
@classmethod @classmethod
# Need to figure it out # Need to figure it out
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 80 return 80
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 80
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"] return ["quantize_config.json"]
......
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
import logging import logging
import sys
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
...@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -20,6 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear
from sglang.srt.utils import get_device_capability
# Initialize logger for the module # Initialize logger for the module
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig): ...@@ -52,7 +54,20 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 89 # Minimum hardware capability (e.g., Hopper GPUs). if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 89
# Vendors can update
return False
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
......
...@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( ...@@ -14,7 +14,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
cutlass_fp8_supported, cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz, normalize_e4m3fn_to_e4m3fnuz,
) )
from sglang.srt.utils import is_hip from sglang.srt.utils import get_device_capability, is_hip
_is_hip = is_hip() _is_hip = is_hip()
...@@ -35,8 +35,21 @@ class W8A8Fp8Config(QuantizationConfig): ...@@ -35,8 +35,21 @@ class W8A8Fp8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 89 return 89
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 89
# Vendors can update
return False
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> str:
return "w8a8_fp8" return "w8a8_fp8"
......
import sys
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Dict, List, Optional
import torch import torch
...@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import ( ...@@ -18,6 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizeMethodBase, QuantizeMethodBase,
) )
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import get_device_capability
class W8A8Int8Config(QuantizationConfig): class W8A8Int8Config(QuantizationConfig):
...@@ -36,8 +38,21 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -36,8 +38,21 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
if hasattr(torch, "cuda") and torch.cuda.is_available():
return 75 return 75
# Vendors can update
return sys.maxsize
@classmethod
def get_availability(cls) -> bool:
major, minor = get_device_capability()
if hasattr(torch, "cuda") and torch.cuda.is_available():
return major * 10 + minor >= 75
# Vendors can update
return False
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> str:
return "w8a8_int8" return "w8a8_int8"
......
...@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp): ...@@ -69,6 +69,7 @@ class RotaryEmbedding(CustomOp):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
) -> None: ) -> None:
super().__init__() super().__init__()
self.head_size = head_size self.head_size = head_size
...@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp): ...@@ -77,6 +78,7 @@ class RotaryEmbedding(CustomOp):
self.base = base self.base = base
self.is_neox_style = is_neox_style self.is_neox_style = is_neox_style
self.dtype = dtype self.dtype = dtype
self.device = device
cache = self._compute_cos_sin_cache() cache = self._compute_cos_sin_cache()
# NOTE(ByronHsu): cache needs to be in FP32 for numerical stability # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
...@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding): ...@@ -283,12 +285,19 @@ class LinearScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
scaling_factors: Union[List[float], float], scaling_factors: Union[List[float], float],
dtype: torch.dtype, dtype: torch.dtype,
device: str,
) -> None: ) -> None:
if isinstance(scaling_factors, float): if isinstance(scaling_factors, float):
scaling_factors = [scaling_factors] scaling_factors = [scaling_factors]
self.scaling_factors: List[float] = scaling_factors # noqa self.scaling_factors: List[float] = scaling_factors # noqa
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
# Lazy initialized. # Lazy initialized.
self._scaling_factor_to_offset: Dict[float, int] self._scaling_factor_to_offset: Dict[float, int]
...@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): ...@@ -347,10 +356,17 @@ class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
def _compute_cos_sin_cache(self) -> torch.Tensor: def _compute_cos_sin_cache(self) -> torch.Tensor:
...@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -434,6 +450,7 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
*, *,
extrapolation_factor: float = 1, extrapolation_factor: float = 1,
attn_factor: float = 1, attn_factor: float = 1,
...@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding): ...@@ -448,7 +465,13 @@ class YaRNScalingRotaryEmbedding(RotaryEmbedding):
# Get n-d magnitude scaling corrected for interpolation # Get n-d magnitude scaling corrected for interpolation
self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor) self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
...@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -645,6 +668,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool, is_neox_style: bool,
scaling_factor: float, scaling_factor: float,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
*, *,
extrapolation_factor: float = 1, extrapolation_factor: float = 1,
attn_factor: float = 1, attn_factor: float = 1,
...@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -652,7 +676,6 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
beta_slow: int = 1, beta_slow: int = 1,
mscale: float = 1, mscale: float = 1,
mscale_all_dim: float = 0, mscale_all_dim: float = 0,
device: Optional[str] = "cuda",
) -> None: ) -> None:
self.scaling_factor = scaling_factor self.scaling_factor = scaling_factor
self.extrapolation_factor = extrapolation_factor self.extrapolation_factor = extrapolation_factor
...@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -665,9 +688,14 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
/ yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
* attn_factor * attn_factor
) )
self.device = device
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
...@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -762,6 +790,7 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
scaling_factor: float, scaling_factor: float,
low_freq_factor: float, low_freq_factor: float,
high_freq_factor: float, high_freq_factor: float,
...@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding): ...@@ -772,7 +801,13 @@ class Llama3RotaryEmbedding(RotaryEmbedding):
self.high_freq_factor = high_freq_factor self.high_freq_factor = high_freq_factor
self.orig_max_position = orig_max_position self.orig_max_position = orig_max_position
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
...@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding): ...@@ -810,10 +845,17 @@ class MRotaryEmbedding(RotaryEmbedding):
base: int, base: int,
is_neox_style: bool, is_neox_style: bool,
dtype: torch.dtype, dtype: torch.dtype,
device: str,
mrope_section: Optional[List[int]] = None, mrope_section: Optional[List[int]] = None,
) -> None: ) -> None:
super().__init__( super().__init__(
head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype head_size,
rotary_dim,
max_position_embeddings,
base,
is_neox_style,
dtype,
device,
) )
self.mrope_section = mrope_section self.mrope_section = mrope_section
...@@ -1003,9 +1045,14 @@ def get_rope( ...@@ -1003,9 +1045,14 @@ def get_rope(
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0, partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
) -> RotaryEmbedding: ) -> RotaryEmbedding:
if dtype is None: if dtype is None:
dtype = torch.get_default_dtype() dtype = torch.get_default_dtype()
if device is None:
from sglang.srt.managers.schedule_batch import global_server_args_dict
device = global_server_args_dict["device"]
if rope_scaling is not None: if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls # Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = { rope_scaling_tuple = {
...@@ -1030,7 +1077,7 @@ def get_rope( ...@@ -1030,7 +1077,7 @@ def get_rope(
if rope_scaling is None: if rope_scaling is None:
rotary_emb = RotaryEmbedding( rotary_emb = RotaryEmbedding(
head_size, rotary_dim, max_position, base, is_neox_style, dtype head_size, rotary_dim, max_position, base, is_neox_style, dtype, device
) )
else: else:
if "rope_type" in rope_scaling: if "rope_type" in rope_scaling:
...@@ -1052,6 +1099,7 @@ def get_rope( ...@@ -1052,6 +1099,7 @@ def get_rope(
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
device,
scaling_factor, scaling_factor,
low_freq_factor, low_freq_factor,
high_freq_factor, high_freq_factor,
...@@ -1066,6 +1114,7 @@ def get_rope( ...@@ -1066,6 +1114,7 @@ def get_rope(
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
device,
mrope_section=rope_scaling["mrope_section"], mrope_section=rope_scaling["mrope_section"],
) )
else: else:
...@@ -1076,6 +1125,7 @@ def get_rope( ...@@ -1076,6 +1125,7 @@ def get_rope(
base, base,
is_neox_style, is_neox_style,
dtype, dtype,
device,
) )
elif scaling_type == "linear": elif scaling_type == "linear":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
...@@ -1087,6 +1137,7 @@ def get_rope( ...@@ -1087,6 +1137,7 @@ def get_rope(
is_neox_style, is_neox_style,
scaling_factor, scaling_factor,
dtype, dtype,
device,
) )
elif scaling_type == "dynamic": elif scaling_type == "dynamic":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
...@@ -1098,6 +1149,7 @@ def get_rope( ...@@ -1098,6 +1149,7 @@ def get_rope(
is_neox_style, is_neox_style,
scaling_factor, scaling_factor,
dtype, dtype,
device,
) )
elif scaling_type == "yarn": elif scaling_type == "yarn":
scaling_factor = rope_scaling["factor"] scaling_factor = rope_scaling["factor"]
...@@ -1116,6 +1168,7 @@ def get_rope( ...@@ -1116,6 +1168,7 @@ def get_rope(
is_neox_style, is_neox_style,
scaling_factor, scaling_factor,
dtype, dtype,
device,
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "deepseek_yarn": elif scaling_type == "deepseek_yarn":
...@@ -1143,6 +1196,7 @@ def get_rope( ...@@ -1143,6 +1196,7 @@ def get_rope(
is_neox_style, is_neox_style,
scaling_factor, scaling_factor,
dtype, dtype,
device,
**extra_kwargs, **extra_kwargs,
) )
elif scaling_type == "longrope": elif scaling_type == "longrope":
...@@ -1253,9 +1307,7 @@ def get_rope_wrapper( ...@@ -1253,9 +1307,7 @@ def get_rope_wrapper(
rope_scaling: Optional[Dict[str, Any]] = None, rope_scaling: Optional[Dict[str, Any]] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
partial_rotary_factor: float = 1.0, partial_rotary_factor: float = 1.0,
device: Optional[str] = None,
): ):
if device != "cpu":
return get_rope( return get_rope(
head_size, head_size,
rotary_dim, rotary_dim,
...@@ -1266,15 +1318,3 @@ def get_rope_wrapper( ...@@ -1266,15 +1318,3 @@ def get_rope_wrapper(
dtype, dtype,
partial_rotary_factor, partial_rotary_factor,
) )
return get_rope_cpu(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
rope_scaling,
dtype,
partial_rotary_factor,
device,
)
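With the `device` argument threaded through every rotary-embedding constructor, an embedding can now be built explicitly for a non-CUDA target. A minimal construction sketch (assuming the module imports cleanly on a CPU-only build; the positional argument order follows the `super().__init__` calls above):

```python
import torch

from sglang.srt.layers.rotary_embedding import RotaryEmbedding

# (head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype, device)
rope = RotaryEmbedding(64, 64, 4096, 10000, True, torch.float32, "cpu")
print(rope.head_size, rope.device)  # 64 cpu
```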
...@@ -35,6 +35,8 @@ from sglang.srt.distributed import ( ...@@ -35,6 +35,8 @@ from sglang.srt.distributed import (
set_custom_all_reduce, set_custom_all_reduce,
) )
from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import ( from sglang.srt.layers.dp_attention import (
get_attention_tp_group, get_attention_tp_group,
get_attention_tp_size, get_attention_tp_size,
......
...@@ -108,6 +108,8 @@ def _get_quantization_config( ...@@ -108,6 +108,8 @@ def _get_quantization_config(
quant_config = get_quant_config(model_config, load_config) quant_config = get_quant_config(model_config, load_config)
major, minor = get_device_capability() major, minor = get_device_capability()
if not hasattr(quant_config, "get_availability"):
# TODO: update vLLM quantization configs to support get_availability
if major is not None and minor is not None: if major is not None and minor is not None:
assert 0 <= minor < 10 assert 0 <= minor < 10
capability = major * 10 + minor capability = major * 10 + minor
...@@ -118,6 +120,14 @@ def _get_quantization_config( ...@@ -118,6 +120,14 @@ def _get_quantization_config(
f"Minimum capability: {quant_config.get_min_capability()}. " f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}." f"Current capability: {capability}."
) )
else:
if not quant_config.get_availability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {major, minor}."
)
supported_dtypes = quant_config.get_supported_act_dtypes() supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes: if model_config.dtype not in supported_dtypes:
raise ValueError( raise ValueError(
......
...@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import ( ...@@ -55,7 +55,7 @@ from sglang.srt.layers.quantization.int8_utils import (
block_dequant as int8_block_dequant, block_dequant as int8_block_dequant,
) )
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.rotary_embedding import get_rope_wrapper
from sglang.srt.layers.vocab_parallel_embedding import ( from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module): ...@@ -305,7 +305,6 @@ class DeepseekV2Attention(nn.Module):
base=rope_theta, base=rope_theta,
rope_scaling=rope_scaling, rope_scaling=rope_scaling,
is_neox_style=False, is_neox_style=False,
device=global_server_args_dict["device"],
) )
if rope_scaling: if rope_scaling:
...@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -501,7 +500,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if rope_scaling: if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn" rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope( self.rotary_emb = get_rope_wrapper(
qk_rope_head_dim, qk_rope_head_dim,
rotary_dim=qk_rope_head_dim, rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings, max_position=max_position_embeddings,
...@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -646,19 +645,20 @@ class DeepseekV2AttentionMLA(nn.Module):
) )
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
if self.w_kc.dtype == torch.float8_e4m3fnuz: if self.w_kc.dtype == torch.float8_e4m3fnuz: # hip only
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
q_nope_out = torch.bmm( q_nope_out = torch.bmm(
q_nope.to(torch.bfloat16).transpose(0, 1), q_nope.to(torch.bfloat16).transpose(0, 1),
self.w_kc.to(torch.bfloat16) * self.w_scale, self.w_kc.to(torch.bfloat16) * self.w_scale,
) )
elif self.w_kc.dtype == torch.float8_e4m3fn: elif self.w_kc.dtype == torch.float8_e4m3fn and is_cuda_available():
q_nope_val, q_nope_scale = input_to_float8( q_nope_val, q_nope_scale = input_to_float8(
q_nope.transpose(0, 1), torch.float8_e4m3fn q_nope.transpose(0, 1), torch.float8_e4m3fn
) )
q_nope_out = bmm_fp8( q_nope_out = bmm_fp8(
q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
) )
else: else:
q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc) q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1) q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
...@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -677,13 +677,13 @@ class DeepseekV2AttentionMLA(nn.Module):
attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch) attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank) attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fnuz: if self.w_vc.dtype == torch.float8_e4m3fnuz or not is_cuda_available():
# TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
attn_bmm_output = torch.bmm( attn_bmm_output = torch.bmm(
attn_output.to(torch.bfloat16).transpose(0, 1), attn_output.to(torch.bfloat16).transpose(0, 1),
self.w_vc.to(torch.bfloat16) * self.w_scale, self.w_vc.to(torch.bfloat16) * self.w_scale,
) )
elif self.w_vc.dtype == torch.float8_e4m3fn: elif self.w_vc.dtype == torch.float8_e4m3fn and is_cuda_available():
attn_output_val, attn_output_scale = input_to_float8( attn_output_val, attn_output_scale = input_to_float8(
attn_output.transpose(0, 1), torch.float8_e4m3fn attn_output.transpose(0, 1), torch.float8_e4m3fn
) )
...@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module): ...@@ -694,6 +694,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_scale, self.w_scale,
torch.bfloat16, torch.bfloat16,
) )
else: else:
attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc) attn_bmm_output = torch.bmm(attn_output.transpose(0, 1), self.w_vc)
attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2) attn_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
......
...@@ -126,6 +126,14 @@ def is_cuda_available(): ...@@ -126,6 +126,14 @@ def is_cuda_available():
return is_cuda() return is_cuda()
def is_triton_available():
if is_cuda() or is_xpu() or is_hip():
return get_bool_env_var("TRITON_AVAILABLE", default="true")
else:
# Update once CPU/HPU support Triton
return False
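Because `is_triton_available()` reads an environment variable on CUDA/XPU/HIP, the pure-torch fallbacks can also be exercised on a GPU machine for debugging. A sketch, assuming `get_bool_env_var` treats values other than true/1 as disabled and that the variable is set before the sglang modules are imported:

```python
import os

# Force the native torch paths: per_token_group_quant_fp8 / w8a8_block_fp8_matmul
# then dispatch to their native_* implementations.
os.environ["TRITON_AVAILABLE"] = "false"

from sglang.srt.utils import is_triton_available

print(is_triton_available())  # always False on CPU/HPU; False here because of the override
```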
def enable_show_time_cost(): def enable_show_time_cost():
global show_time_cost global show_time_cost
show_time_cost = True show_time_cost = True
...@@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]: ...@@ -1136,6 +1144,7 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
major, minor = None, None major, minor = None, None
if hasattr(torch, "cuda") and torch.cuda.is_available(): if hasattr(torch, "cuda") and torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability(device_id) major, minor = torch.cuda.get_device_capability(device_id)
assert 0 <= minor < 10
if hasattr(torch, "xpu") and torch.xpu.is_available(): if hasattr(torch, "xpu") and torch.xpu.is_available():
major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split( major, minor, *_ = torch.xpu.get_device_capability(device_id)["version"].split(
......
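`get_device_capability()` returns `(None, None)` when neither CUDA nor XPU is present, which is what sends the quantization configs above down their vendor fallback branches. A small illustration:

```python
from sglang.srt.utils import get_device_capability

major, minor = get_device_capability()
if major is None:
    # CPU-only machine: no capability is reported, so quantization configs
    # are gated by get_availability() instead of get_min_capability().
    print("no device capability reported")
else:
    print(f"device capability: {major}.{minor}")
```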
...@@ -7,6 +7,8 @@ import torch ...@@ -7,6 +7,8 @@ import torch
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.fp8_kernel import ( from sglang.srt.layers.quantization.fp8_kernel import (
native_per_token_group_quant_fp8,
native_w8a8_block_fp8_matmul,
per_token_group_quant_fp8, per_token_group_quant_fp8,
static_quant_fp8, static_quant_fp8,
w8a8_block_fp8_matmul, w8a8_block_fp8_matmul,
...@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ...@@ -15,35 +17,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
_is_cuda = torch.cuda.is_available() and torch.version.cuda _is_cuda = torch.cuda.is_available() and torch.version.cuda
# For test
def native_per_token_group_quant_fp8(
x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn
):
"""Function to perform per-token-group quantization on an input tensor `x` using native torch.
It converts the tensor values into float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Note that only `torch.float8_e4m3fn` is supported for now.
"""
assert (
x.shape[-1] % group_size == 0
), "the last dimension of `x` cannot be divisible by `group_size`"
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))
return x_q, x_s
class TestPerTokenGroupQuantFP8(unittest.TestCase): class TestPerTokenGroupQuantFP8(unittest.TestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32] DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048] NUM_TOKENS = [7, 83, 2048]
...@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase): ...@@ -154,62 +127,6 @@ class TestStaticQuantFP8(unittest.TestCase):
self._static_quant_fp8(*params) self._static_quant_fp8(*params)
# For test
def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
"""This function performs matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A = A.to(torch.float32)
B = B.to(torch.float32)
assert A.shape[-1] == B.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
assert len(block_size) == 2
block_n, block_k = block_size[0], block_size[1]
assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
assert A.shape[:-1] == As.shape[:-1]
M = A.numel() // A.shape[-1]
N, K = B.shape
origin_C_shape = A.shape[:-1] + (N,)
A = A.reshape(M, A.shape[-1])
As = As.reshape(M, As.shape[-1])
n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k
assert n_tiles == Bs.shape[0]
assert k_tiles == Bs.shape[1]
C_shape = (M, N)
C = torch.zeros(C_shape, dtype=torch.float32, device=A.device)
A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)]
B_tiles = [
[
B[
j * block_n : min((j + 1) * block_n, N),
i * block_k : min((i + 1) * block_k, K),
]
for i in range(k_tiles)
]
for j in range(n_tiles)
]
C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)]
As_tiles = [As[:, i : i + 1] for i in range(k_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
a = A_tiles[i]
b = B_tiles[j][i]
c = C_tiles[j]
s = As_tiles[i] * Bs[j][i]
c[:, :] += torch.matmul(a, b.t()) * s
C = C.reshape(origin_C_shape).to(output_dtype)
return C
class TestW8A8BlockFP8Matmul(unittest.TestCase): class TestW8A8BlockFP8Matmul(unittest.TestCase):
if not _is_cuda: if not _is_cuda:
......