Commit 8fc55263 authored by maxiao1

Merge branch 'v0.5.4_dev_maxiao' into 'v0.5.4_dev'

Adapt the w8a8 model

See merge request OpenDAS/sglang!1
parents f6528b74 eb4ba1c2
@@ -615,6 +615,7 @@ class ModelConfig:
             "quark",
             "mxfp4",
             "slimquant_w4a8_marlin",
+            "w8a8_int8",
         ]
         optimized_quantization_methods = [
             "fp8",
...
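This hunk registers "w8a8_int8" in ModelConfig's allow-list of quantization methods so that models quantized this way pass config validation. A minimal sketch of the kind of allow-list check this enables (function and variable names below are hypothetical, not sglang's actual internals):

```python
SUPPORTED_QUANTIZATION_METHODS = [
    "quark",
    "mxfp4",
    "slimquant_w4a8_marlin",
    "w8a8_int8",  # accepted after this commit
]

def verify_quantization(method: str) -> None:
    # Reject any quantization string that is not on the allow-list,
    # mirroring the validation that consumes the list above.
    if method not in SUPPORTED_QUANTIZATION_METHODS:
        raise ValueError(f"Unknown quantization method: {method!r}")
```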
@@ -14,9 +14,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.int8_kernel import (
     per_token_group_quant_int8,
-    per_token_quant_int8,
+    # per_token_quant_int8,
     sglang_per_token_group_quant_int8,
 )
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
...
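This hunk swaps sglang's own per_token_quant_int8 for the lmslim implementation, presumably with the same contract: quantize each token row to int8 with one scale per row. A CPU reference sketch of that contract, assuming the lmslim kernel matches the original semantics:

```python
import torch

def per_token_quant_int8_ref(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Each row (token) gets its own scale so that the row's
    # maximum magnitude maps to 127 in int8.
    scale = x.float().abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)
    return x_q, scale
```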
@@ -22,7 +22,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
-from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+# from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
     apply_module_patch,
@@ -39,6 +40,8 @@ if TYPE_CHECKING:
     CombineInput,
     StandardDispatchOutput,
 )
+
+from lmslim import quant_ops

 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -405,7 +408,7 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
         output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
-        output = int8_scaled_mm(
+        output = quant_ops.triton_scaled_mm(
             x_q_2d,
             layer.weight,
             x_scale_2d,
...
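The last hunk in this file replaces int8_scaled_mm with lmslim's quant_ops.triton_scaled_mm. Assuming the lmslim kernel keeps the contract the call site suggests (int8 activations and weights, per-token activation scales, per-channel weight scales), a CPU reference of the computation would be:

```python
import torch

def scaled_mm_ref(
    x_q: torch.Tensor,       # int8 activations, shape [M, K]
    w_q: torch.Tensor,       # int8 weights, shape [K, N]
    x_scale: torch.Tensor,   # per-token scales, shape [M, 1]
    w_scale: torch.Tensor,   # per-channel scales, shape [1, N]
    out_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    # int8 x int8 GEMM accumulated in int32, then dequantized by the
    # outer product of per-token and per-channel scales.
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32)
    return (acc.float() * x_scale * w_scale).to(out_dtype)
```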
@@ -203,7 +203,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)

 # Detect stragger ranks in model loading
-UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
 # the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
 MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
...