Commit 5663e01d authored by wanglong3
Browse files

Switch default w8a8 gemm impl to blaslt.

parent e5572b2a
......@@ -207,6 +207,7 @@ if TYPE_CHECKING:
VLLM_USE_FUSED_QA_KVA_GEMM: bool = False
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_W8A8_BACKEND: int = 3
def get_default_cache_root():
return os.getenv(
......@@ -1335,6 +1336,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
int(os.getenv("VLLM_DISABLE_SHARED_EXPERTS_STREAM", "1"))
),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
}
# --8<-- [end:env-vars-definition]
......@@ -1399,6 +1407,7 @@ def compute_hash() -> str:
"VLLM_DP_SIZE",
"VLLM_USE_STANDALONE_COMPILE",
"VLLM_FUSED_MOE_CHUNK_SIZE",
"VLLM_W8A8_BACKEND",
]
for key in environment_variables_to_hash:
if key in environment_variables:
......
......@@ -20,6 +20,7 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
PerTensorScaleParameter)
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
import vllm.envs as envs
logger = init_logger(__name__)
......@@ -31,8 +32,8 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_symmetric: bool):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
self.input_symmetric = input_symmetric
@classmethod
......
......@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
......
......@@ -58,6 +58,7 @@ from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
FalconConfig = Union[HF_FalconConfig, RWConfig]
......@@ -393,7 +394,7 @@ class FalconModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.word_embeddings(input_ids)
......
......@@ -31,6 +31,7 @@ from typing import Optional, Union
import torch
from torch import nn
from transformers import Glm4Config
import vllm.envs as envs
class MultiModalConfigProxy:
......@@ -332,7 +333,7 @@ class Glm4ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......
......@@ -453,7 +453,7 @@ class Qwen3MoeModel(nn.Module):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
......
......@@ -37,6 +37,7 @@ from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper,
from vllm.utils import W8a8GetCacheJSON
from vllm import _custom_ops as ops
from vllm.model_executor.utils import pad_weight, gemm_bank_conf
import vllm.envs as envs
class TeleChat2Model(LlamaModel):
......@@ -66,8 +67,7 @@ class TeleChat2Model(LlamaModel):
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment