Commit dcec1430 authored by zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-w8a8-new' into 'v0.9.2-dev'

V0.9.2 dev w8a8 new

See merge request dcutoolkit/deeplearing/vllm!173
parents 513f17a4 333104ab
...@@ -163,7 +163,7 @@ if TYPE_CHECKING: ...@@ -163,7 +163,7 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False VLLM_USE_FLASH_ATTN_PA: bool = False
VLLM_USE_APEX_RN: bool = False VLLM_USE_APEX_RN: bool = False
VLLM_USE_GLOBAL_CACHE13: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1085,6 +1085,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1085,6 +1085,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_APEX_RN": "VLLM_USE_APEX_RN":
lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in lambda: (os.environ.get("VLLM_USE_APEX_RN", "False").lower() in
("true", "1")), ("true", "1")),
# vLLM will use global cache for moe
"VLLM_USE_GLOBAL_CACHE13":
lambda: (os.environ.get("VLLM_USE_GLOBAL_CACHE13", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -44,6 +44,14 @@ from vllm.utils import direct_register_custom_op ...@@ -44,6 +44,14 @@ from vllm.utils import direct_register_custom_op
# from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled # from .rocm_aiter_fused_moe import is_rocm_aiter_moe_enabled
logger = init_logger(__name__) logger = init_logger(__name__)
if envs.VLLM_USE_GLOBAL_CACHE13:
    # Module-level scratch buffer shared by fused-MoE calls. It stays None
    # until the first call to get_moe_cache() allocates it.
    moe_cache_singleton = None

    def get_moe_cache(top_k_num, N, K, device, dtype):
        """Return the shared flat scratch buffer used for the MoE caches.

        The buffer holds at least
        ``envs.VLLM_FUSED_MOE_CHUNK_SIZE * top_k_num * max(N, K)`` elements
        of ``dtype`` on ``device``. An existing buffer is reused whenever it
        is large enough and matches the requested dtype and device;
        otherwise a new one is allocated and replaces the shared buffer.

        Args:
            top_k_num: Number of experts selected per token.
            N: First candidate width of the intermediate activations.
            K: Second candidate width; the larger of ``N`` and ``K`` sizes
                the buffer.
            device: Device the buffer must live on (anything accepted by
                ``torch.device``).
            dtype: Element type of the buffer.

        Returns:
            A 1-D ``torch.Tensor``. It may be larger than requested, so
            callers are expected to slice the prefix they need.
        """
        global moe_cache_singleton

        required_numel = (envs.VLLM_FUSED_MOE_CHUNK_SIZE * top_k_num *
                          max(N, K))
        requested_device = torch.device(device)
        cache = moe_cache_singleton

        # The first caller's shape/dtype/device must not be trusted for all
        # later callers: a bigger layer, or a different dtype or device,
        # would otherwise receive a buffer it cannot slice or view correctly.
        reusable = (
            cache is not None
            and cache.numel() >= required_numel
            and cache.dtype == dtype
            and cache.device.type == requested_device.type
            # A device given without an index (e.g. "cuda") matches any
            # index of that device type.
            and (requested_device.index is None
                 or cache.device.index == requested_device.index))

        if not reusable:
            # NOTE(review): replacing the buffer changes its address; confirm
            # this cannot happen after a CUDA/HIP graph has captured the old
            # buffer.
            cache = torch.empty(required_numel, device=device, dtype=dtype)
            moe_cache_singleton = cache
            logger.info(
                "Initializing moe_cache_singleton shape: %s, "
                "memory: %.2f MB", cache.shape,
                cache.element_size() * cache.numel() / 1024**2)
        return cache
@triton.jit @triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token, def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
...@@ -1494,13 +1502,32 @@ def fused_experts_impl( ...@@ -1494,13 +1502,32 @@ def fused_experts_impl(
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Check constraints. num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
if envs.VLLM_USE_GLOBAL_CACHE13:
cache13 = get_moe_cache(top_k_num, N,K if not use_nn_moe else w2.shape[2], device=hidden_states.device, dtype=hidden_states.dtype)
else:
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype)
if use_int8_w8a8 is True: if use_int8_w8a8 is True:
return fused_experts_impl_int8(hidden_states=hidden_states, return fused_experts_impl_int8(hidden_states=hidden_states,
w1=w1, w1=w1,
w2=w2, w2=w2,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
cache13 = cache13,
inplace=inplace, inplace=inplace,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
...@@ -1527,6 +1554,7 @@ def fused_experts_impl( ...@@ -1527,6 +1554,7 @@ def fused_experts_impl(
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=inplace, inplace=inplace,
cache13 = cache13,
activation=activation, activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8= False, use_fp8_w8a8= False,
...@@ -1565,21 +1593,6 @@ def fused_experts_impl( ...@@ -1565,21 +1593,6 @@ def fused_experts_impl(
torch.float32, torch.float16, torch.bfloat16 torch.float32, torch.float16, torch.bfloat16
] ]
num_tokens = hidden_states.size(0)
if use_nn_moe:
E, _, N = w1.size()
else:
E, N, _ = w1.size()
K = w2.size(1)
if global_num_experts == -1:
global_num_experts = E
top_k_num = topk_ids.size(1)
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
M = min(num_tokens, CHUNK_SIZE)
config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
...@@ -1606,9 +1619,6 @@ def fused_experts_impl( ...@@ -1606,9 +1619,6 @@ def fused_experts_impl(
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
cache13 = torch.empty(M * top_k_num * max(N, K if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N) intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2]) intermediate_cache3 = cache13[:M * top_k_num * (K if not use_nn_moe else w2.shape[2])].view(M, top_k_num, K if not use_nn_moe else w2.shape[2])
......
...@@ -821,7 +821,7 @@ class FusedMoE(torch.nn.Module): ...@@ -821,7 +821,7 @@ class FusedMoE(torch.nn.Module):
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")): if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
if (self.quant_method.__class__.__name__ in ("W8A8Int8MoEMethod")): if (self.quant_method.__class__.__name__ in ("SlimQuantW4A8Int8MoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
......
...@@ -37,7 +37,7 @@ QuantizationMethods = Literal[ ...@@ -37,7 +37,7 @@ QuantizationMethods = Literal[
"auto-round", "auto-round",
"rtn", "rtn",
"blockwise_int8", "blockwise_int8",
"w8a8_int8" "slimquant_w4a8"
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -117,7 +117,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -117,7 +117,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .torchao import TorchAOConfig from .torchao import TorchAOConfig
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config from .blockwise_int8 import BlockInt8Config
from .w8a8_int8 import W8A8Int8Config from .slimquant_w4a8 import SlimQuantW4A8Int8Config
method_to_config: dict[str, type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -150,7 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -150,7 +150,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"auto-round": AutoRoundConfig, "auto-round": AutoRoundConfig,
"rtn": RTNConfig, "rtn": RTNConfig,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"w8a8_int8":W8A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
...@@ -1000,7 +1000,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1000,7 +1000,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
raise ValueError( raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, " "For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.") "dynamic per token quantization. Found static input scales.")
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
...@@ -1089,6 +1091,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1089,6 +1091,9 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
activation: str = "silu", activation: str = "silu",
enable_eplb: bool = False, enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
expert_load_view: Optional[torch.Tensor] = None, expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None,
...@@ -1111,6 +1116,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod): ...@@ -1111,6 +1116,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias)
return fused_experts( return fused_experts(
......
...@@ -40,7 +40,7 @@ def baseline_scaled_mm(a: torch.Tensor, ...@@ -40,7 +40,7 @@ def baseline_scaled_mm(a: torch.Tensor,
return output.to(out_dtype) return output.to(out_dtype)
class W8A8Int8Config(QuantizationConfig): class SlimQuantW4A8Int8Config(QuantizationConfig):
"""Config class for W8A8 Int8 Quantization. """Config class for W8A8 Int8 Quantization.
- Weight: static, per-channel, symmetric - Weight: static, per-channel, symmetric
...@@ -60,14 +60,14 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -60,14 +60,14 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> str:
return "w8a8_int8" return "slimquant_w4a8"
@classmethod @classmethod
def get_config_filenames(cls) -> List[str]: def get_config_filenames(cls) -> List[str]:
return [] return []
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config": def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8Config":
return cls() return cls()
def get_quant_method( def get_quant_method(
...@@ -77,18 +77,18 @@ class W8A8Int8Config(QuantizationConfig): ...@@ -77,18 +77,18 @@ class W8A8Int8Config(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]: ) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self) return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE): elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self) return SlimQuantW4A8Int8MoEMethod(self)
return None return None
def get_scaled_act_names(self) -> List[str]: def get_scaled_act_names(self) -> List[str]:
return [] return []
class W8A8Int8LinearMethod(LinearMethodBase): class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
def __init__(self, quantization_config: W8A8Int8Config): def __init__(self, quantization_config: SlimQuantW4A8Int8Config):
self.quantization_config = quantization_config self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON() self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
...@@ -218,8 +218,8 @@ class W8A8Int8LinearMethod(LinearMethodBase): ...@@ -218,8 +218,8 @@ class W8A8Int8LinearMethod(LinearMethodBase):
bias=bias) bias=bias)
class W8A8Int8MoEMethod: class SlimQuantW4A8Int8MoEMethod:
"""MoE method for INT8. """MoE method for W4A8INT8.
Supports loading INT8 checkpoints with static weight scale and Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale. dynamic/static activation scale.
Also supports loading quantized FP16/BF16 model checkpoints with dynamic Also supports loading quantized FP16/BF16 model checkpoints with dynamic
...@@ -354,7 +354,7 @@ class W8A8Int8MoEMethod: ...@@ -354,7 +354,7 @@ class W8A8Int8MoEMethod:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb: if enable_eplb:
raise NotImplementedError( raise NotImplementedError(
"EPLB not supported for `W8A8Int8MoeMethod` yet.") "EPLB not supported for `SlimQuantW4A8Int8MoEMethod` yet.")
# Expert selection # Expert selection
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
......
...@@ -180,7 +180,7 @@ class RocmPlatform(Platform): ...@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","w8a8_int8","awq_marlin" "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin"
] ]
@classmethod @classmethod
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment