Commit 5663e01d authored by wanglong3's avatar wanglong3
Browse files

Switch default w8a8 gemm impl to blaslt.

parent e5572b2a
...@@ -207,6 +207,7 @@ if TYPE_CHECKING: ...@@ -207,6 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_W8A8_BACKEND: int = 3
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1335,6 +1336,13 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1335,6 +1336,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1")) int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
), ),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
...@@ -1399,6 +1407,7 @@ def compute_hash() -> str: ...@@ -1399,6 +1407,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE", "VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE", "VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE", "VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_W8A8_BACKEND",
] ]
for key in environment_variables_to_hash: for key in environment_variables_to_hash:
if key in environment_variables: if key in environment_variables:
......
...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter) PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme): ...@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool): input_symmetric: bool):
self.strategy = strategy self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric self.input_symmetric = input_symmetric
@classmethod @classmethod
......
...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
...@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter, ...@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
FalconConfig = Union[HF_FalconConfig, RWConfig] FalconConfig = Union[HF_FalconConfig, RWConfig]
...@@ -393,7 +394,7 @@ class FalconModel(nn.Module): ...@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids) return self.word_embeddings(input_ids)
......
...@@ -31,6 +31,7 @@ from typing import Optional, Union ...@@ -31,6 +31,7 @@ from typing import Optional, Union
import torch import torch
from torch import nn from torch import nn
from transformers import Glm4Config from transformers import Glm4Config
import vllm.envs as envs
class MultiModalConfigProxy: class MultiModalConfigProxy:
...@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
......
...@@ -453,7 +453,7 @@ class Qwen3MoeModel(nn.Module): ...@@ -453,7 +453,7 @@ class Qwen3MoeModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids) return self.embed_tokens(input_ids)
......
...@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, ...@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
class TeleChat2Model(LlamaModel): class TeleChat2Model(LlamaModel):
...@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel): ...@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def load_weights(self, weights: Iterable[tuple[str, def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]: torch.Tensor]]) -> set[str]:
stacked_params_mapping = [ stacked_params_mapping = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment