Unverified Commit 99c92ff2 authored by jacky.cheng, committed by GitHub

[AMD] Support a new flag to disable quant on ParallelLinear layers if required (#11811)

parent 6ade6a02
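
Usage note (not part of the diff): with this change, setting SGLANG_ROCM_DISABLE_LINEARQUANT in the environment of a ROCm build skips quantization for the parallel linear layers touched below. A minimal, hedged sketch of turning the flag on from Python before SGLang reads it; treating "1" as truthy is an assumption based on typical boolean env parsing, so check get_bool_env_var for the spellings it actually accepts.

```python
# Hedged usage sketch: export the new flag before SGLang's modules are imported,
# since the module-level check in the linear layer file reads it once at import time.
# The truthy spelling "1" is an assumption, not confirmed by this diff.
import os

os.environ["SGLANG_ROCM_DISABLE_LINEARQUANT"] = "1"

# ... then import / launch SGLang as usual (entrypoint not shown in this diff).
```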
@@ -158,6 +158,7 @@ class Envs:
     # AMD & ROCm
     SGLANG_USE_AITER = EnvBool(False)
     SGLANG_ROCM_FUSED_DECODE_MLA = EnvBool(False)
+    SGLANG_ROCM_DISABLE_LINEARQUANT = EnvBool(False)
     # Quantization
     SGLANG_INT4_WEIGHT = EnvBool(False)
...
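
The new entry is registered with a default of EnvBool(False), so behaviour is unchanged unless the variable is set. Below is a self-contained approximation of how such a boolean flag is typically parsed; it is not SGLang's actual EnvBool or get_bool_env_var implementation, and the accepted truthy spellings are an assumption.

```python
# Standalone approximation of boolean env-flag parsing; read_bool_flag is a
# hypothetical helper for illustration, not an SGLang API.
import os


def read_bool_flag(name: str, default: bool = False) -> bool:
    raw = os.getenv(name)
    if raw is None:
        return default  # matches EnvBool(False): unset means disabled
    return raw.strip().lower() in ("1", "true", "yes", "on")


if __name__ == "__main__":
    print(read_bool_flag("SGLANG_ROCM_DISABLE_LINEARQUANT"))  # False unless exported
```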
@@ -32,7 +32,7 @@ from sglang.srt.layers.parameter import (
 )
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.utils import pad_or_narrow_weight
-from sglang.srt.utils import is_cpu, is_npu, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_cpu, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import (
@@ -40,6 +40,11 @@ if TYPE_CHECKING:
         QuantizeMethodBase,
     )

+_is_hip = is_hip()
+_disable_hip_linear_quant = _is_hip and get_bool_env_var(
+    "SGLANG_ROCM_DISABLE_LINEARQUANT"
+)
+
 logger = logging.getLogger(__name__)

 WEIGHT_LOADER_V2_SUPPORTED = [
@@ -824,6 +829,7 @@ class QKVParallelLinear(ColumnParallelLinear):
             self.num_kv_heads * self.head_size * tp_size,  # v_proj
         ]
         self.use_presharded_weights = load_presharded_attn
+        quant_config = None if _disable_hip_linear_quant else quant_config
         super().__init__(
             input_size=input_size,
@@ -1225,6 +1231,7 @@ class RowParallelLinear(LinearBase):
         tp_size: Optional[int] = None,
         use_presharded_weights: bool = False,
     ):
+        quant_config = None if _disable_hip_linear_quant else quant_config
         super().__init__(
             input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
         )
...
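
To make the effect of the two one-line insertions concrete, here is a self-contained sketch of the pattern: on a HIP/ROCm build with the flag set, quant_config is forced to None before the base constructor runs, so the layer falls back to an unquantized linear method. The classes below are simplified stand-ins, not SGLang's real QKVParallelLinear/RowParallelLinear, and the get_quant_method call shown is illustrative rather than the library's exact signature.

```python
# Simplified stand-ins illustrating the gating added in this commit; the real
# SGLang classes take many more arguments and a richer quant_config API.
import os
from typing import Optional


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Approximation of sglang.srt.utils.get_bool_env_var.
    return os.getenv(name, default).strip().lower() in ("1", "true")


_is_hip = True  # pretend this is a ROCm build for the sake of the example
_disable_hip_linear_quant = _is_hip and get_bool_env_var(
    "SGLANG_ROCM_DISABLE_LINEARQUANT"
)


class UnquantizedLinearMethod:
    """Placeholder for the unquantized fallback method."""


class LinearBase:
    def __init__(self, quant_config: Optional[object] = None):
        # With quant_config=None the layer keeps the unquantized method.
        if quant_config is None:
            self.quant_method = UnquantizedLinearMethod()
        else:
            self.quant_method = quant_config.get_quant_method(self)


class RowParallelLinear(LinearBase):
    def __init__(self, quant_config: Optional[object] = None):
        # Mirrors the inserted line: drop the quant config when the flag is on.
        quant_config = None if _disable_hip_linear_quant else quant_config
        super().__init__(quant_config)


if __name__ == "__main__":
    layer = RowParallelLinear(quant_config=None)
    print(type(layer.quant_method).__name__)  # UnquantizedLinearMethod
```

Because the diff evaluates the flag at module import time, the environment variable presumably has to be set before sglang.srt.layers.linear is first imported by the launching process.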