Commit 04343d9d authored by zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-tmp' into 'v0.11.0-dev'

On the gfx928 architecture, always use the Triton GEMM: VLLM_W8A8_BACKEND is forced to 1 there, overriding any value set in the environment.

See merge request dcutoolkit/deeplearing/vllm!428
parents 6de849de f551bd1d
...@@ -6,6 +6,7 @@ import json ...@@ -6,6 +6,7 @@ import json
import os import os
import sys import sys
import tempfile import tempfile
import torch
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -1704,7 +1705,9 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1704,7 +1705,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
# cutlass: 2 (will remove in the future) # cutlass: 2 (will remove in the future)
# blaslt: 3 (default) # blaslt: 3 (default)
# rocblas: others # rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")), "VLLM_W8A8_BACKEND": lambda: int(
1 if "gfx928" in torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] else os.getenv("VLLM_W8A8_BACKEND", "3")),
# Force using Triton MoE path (disable Marlin W16A16 MoE). # Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON": "VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
......
...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase): ...@@ -92,8 +92,8 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: SlimQuantW4A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton = W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy = envs.VLLM_W8A8_BACKEND
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0] n=layer.weight.shape[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment