"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "719c1ca468537d2be2616ddc3163236af7f5bd62"
Commit 69f30ae0 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.2-dev' into v0.9.2-dev

parents d04683a4 4a946680
...@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend, ...@@ -19,8 +19,13 @@ from vllm.attention.backends.flash_attn import (FlashAttentionBackend,
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils import async_tensor_h2d from vllm.utils import async_tensor_h2d
from vllm.vllm_flash_attn import (flash_attn_varlen_func, from vllm.platforms import current_platform
flash_attn_with_kvcache, sparse_attn_func) if not current_platform.is_rocm():
from vllm.vllm_flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
else:
from flash_attn import (flash_attn_varlen_func,
flash_attn_with_kvcache, sparse_attn_func)
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder from vllm.worker.model_runner import ModelInputForGPUBuilder
......
...@@ -246,12 +246,33 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -246,12 +246,33 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
device, non_blocking=True) device, non_blocking=True)
else: else:
block_tables = make_tensor_with_pad( has_empty: bool = any(len(bt) == 0 for bt in self.block_tables)
self.block_tables, has_non_empty = any(len(bt) > 0 for bt in self.block_tables)
pad=0, max_block_length = 0
dtype=torch.int, if has_empty and has_non_empty:
device=device, for inter_data in self.input_builder.inter_data_list:
) block_tables = inter_data.block_tables
if block_tables:
for seq_id in inter_data.seq_ids:
if seq_id in block_tables:
block_table = block_tables[seq_id]
max_block_length = max(max_block_length, len(block_table))
if max_block_length >0:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
max_len=max_block_length,
)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, "query_lens: {}".format(query_lens) assert max_query_len > 0, "query_lens: {}".format(query_lens)
assert device is not None assert device is not None
......
...@@ -893,7 +893,8 @@ class ModelConfig: ...@@ -893,7 +893,8 @@ class ModelConfig:
optimized_quantization_methods = [ optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8", "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
"quark", "modelopt_fp4", "bitblas", "gptq_bitblas" "quark", "modelopt_fp4", "bitblas", "gptq_bitblas",
"slimquant_w4a8","slimquant_w4a8_marlin"
] ]
if self.quantization is not None: if self.quantization is not None:
self.quantization = cast(me_quant.QuantizationMethods, self.quantization = cast(me_quant.QuantizationMethods,
...@@ -920,6 +921,7 @@ class ModelConfig: ...@@ -920,6 +921,7 @@ class ModelConfig:
"awq_marlin", "awq_marlin",
"ipex", "ipex",
"moe_wna16", "moe_wna16",
"slimquant_w4a8_marlin"
] ]
quantization_methods = [ quantization_methods = [
q for q in supported_quantization if q not in overrides q for q in supported_quantization if q not in overrides
......
...@@ -1107,8 +1107,8 @@ class EngineArgs: ...@@ -1107,8 +1107,8 @@ class EngineArgs:
"Cuda graph is not supported with DualChunkFlashAttention. " "Cuda graph is not supported with DualChunkFlashAttention. "
"To run the model in eager mode, set 'enforce_eager=True' " "To run the model in eager mode, set 'enforce_eager=True' "
"or use '--enforce-eager' in the CLI.") "or use '--enforce-eager' in the CLI.")
assert current_platform.is_cuda(), ( assert current_platform.is_cuda() or current_platform.is_rocm(), (
"DualChunkFlashAttention is only supported on CUDA platform.") "DualChunkFlashAttention is supported on CUDA/ROCM platform.")
assert not use_v1, ( assert not use_v1, (
"DualChunkFlashAttention is not supported on V1 engine. " "DualChunkFlashAttention is not supported on V1 engine. "
"To run the model in V0 engine, try set 'VLLM_USE_V1=0'") "To run the model in V0 engine, try set 'VLLM_USE_V1=0'")
......
...@@ -811,9 +811,9 @@ class FusedMoE(torch.nn.Module): ...@@ -811,9 +811,9 @@ class FusedMoE(torch.nn.Module):
"CompressedTensorsWNA16MoEMethod")): "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod")): if (self.quant_method.__class__.__name__ in ("BlockInt8MoEMethod",
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition "SlimQuantW4A8Int8MoEMethod",
if (self.quant_method.__class__.__name__ in ("SlimQuantW4A8Int8MoEMethod")): "SlimQuantW4A8Int8MarlinMoEMethod")):
moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition moe_quant_params["intermediate_size"] = self.intermediate_size_per_partition
......
...@@ -37,7 +37,8 @@ QuantizationMethods = Literal[ ...@@ -37,7 +37,8 @@ QuantizationMethods = Literal[
"auto-round", "auto-round",
"rtn", "rtn",
"blockwise_int8", "blockwise_int8",
"slimquant_w4a8" "slimquant_w4a8",
"slimquant_w4a8_marlin"
] ]
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods)) QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
...@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -118,6 +119,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
from .tpu_int8 import Int8TpuConfig from .tpu_int8 import Int8TpuConfig
from .blockwise_int8 import BlockInt8Config from .blockwise_int8 import BlockInt8Config
from .slimquant_w4a8 import SlimQuantW4A8Int8Config from .slimquant_w4a8 import SlimQuantW4A8Int8Config
from .slimquant_w4a8_marlin import SlimQuantW4A8Int8MarlinConfig
method_to_config: dict[str, type[QuantizationConfig]] = { method_to_config: dict[str, type[QuantizationConfig]] = {
"aqlm": AQLMConfig, "aqlm": AQLMConfig,
...@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -151,6 +153,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"rtn": RTNConfig, "rtn": RTNConfig,
"blockwise_int8": BlockInt8Config, "blockwise_int8": BlockInt8Config,
"slimquant_w4a8":SlimQuantW4A8Int8Config, "slimquant_w4a8":SlimQuantW4A8Int8Config,
"slimquant_w4a8_marlin":SlimQuantW4A8Int8MarlinConfig,
} }
# Update the `method_to_config` with customized quantization methods. # Update the `method_to_config` with customized quantization methods.
method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG) method_to_config.update(_CUSTOMIZED_METHOD_TO_QUANT_CONFIG)
......
from typing import Any, Callable, Dict, List, Optional
import os
import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.linear import (LinearBase,LinearMethodBase)
from vllm.model_executor.layers.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.w4a8_utils import w4a8_2_marlin_weight
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
ModelWeightParameter)
from vllm.model_executor.layers.quantization.slimquant_w4a8 import SlimQuantW4A8Int8LinearMethod
try:
from lmslim.layers.fused_moe.fuse_moe_w4a8_marlin import fused_experts_impl_w4a8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
class MarlinMoeWorkspace:
"""
Singleton manager for device-specific workspace buffers used by w4a8 Marlin-MoE.
global_reduce_buffer will take 1.5MB * cus (about 120MB for BW200) memoery in each device
"""
_instances = {}
def __new__(cls, device):
if device not in cls._instances:
instance = super().__new__(cls)
instance._initialized = False
cls._instances[device] = instance
return cls._instances[device]
def __init__(self, device):
if self._initialized:
return
sms = torch.cuda.get_device_properties(device).multi_processor_count
self.workspace = torch.zeros(
500, dtype=torch.int, device=device, requires_grad=False
)
self.global_reduce_buffer = torch.zeros(
sms * 6 * 128 * 512, dtype=torch.int, device=device, requires_grad=False
)
self._initialized = True
def get_buffers(self):
return self.workspace, self.global_reduce_buffer
def baseline_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
scales= scale_a* scale_b.T
gemmout= torch.mm(
a.to(dtype=torch.float32), b.to(dtype=torch.float32))
output = (scales *gemmout).to(out_dtype)
if bias is not None:
output = output + bias
return output.to(out_dtype)
class SlimQuantW4A8Int8MarlinConfig(QuantizationConfig):
"""Config class for W4A8 Int8 Quantization.
- Weight: static, per-channel, symmetric
- Activation: dynamic, per-token, symmetric
"""
def __init__(self):
pass
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 75
@classmethod
def get_name(self) -> str:
return "slimquant_w4a8_marlin"
@classmethod
def get_config_filenames(cls) -> List[str]:
return []
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SlimQuantW4A8Int8MarlinConfig":
return cls()
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if hf_quant_cfg.get("quant_method") == "slimquant_w4a8" \
and user_quant == "slimquant_w4a8_marlin":
return cls.get_name()
return None
def get_quant_method(
self,
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
if isinstance(layer, LinearBase):
return SlimQuantW4A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return SlimQuantW4A8Int8MarlinMoEMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
class SlimQuantW4A8Int8MarlinMoEMethod:
"""MoE method for W4A8INT8 Marlin.
Supports loading INT8 checkpoints with static weight scale and
dynamic/static activation scale.
Args:
quant_config: The quantization config.
"""
def __new__(cls, *args, **kwargs):
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
new_cls = type(
cls.__name__,
(FusedMoEMethodBase,),
{
"__init__": original_init,
**{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
},
)
obj = super(new_cls, new_cls).__new__(new_cls)
obj.__init__(*args, **kwargs)
return obj
return super().__new__(cls)
def __init__(self, quant_config):
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
tp_size = get_tensor_model_parallel_world_size()
# WEIGHTS
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size//2, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size//2, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
w13_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
w2_weight_scale = torch.nn.Parameter(
torch.ones(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
w13_input_scale = None
layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = None
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data, requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data, requires_grad=False
)
w1_marlin_list = []
for e in range(layer.w13_weight.shape[0]):
w1_marlin_in = w4a8_2_marlin_weight(layer.w13_weight[e])
w1_marlin_list.append(w1_marlin_in)
layer.w13_weight = Parameter(torch.stack(w1_marlin_list, dim=0), requires_grad=False)
w2_marlin_list = []
for e in range(layer.w2_weight.shape[0]):
w2_marlin_in = w4a8_2_marlin_weight(layer.w2_weight[e])
w2_marlin_list.append(w2_marlin_in)
layer.w2_weight = Parameter(torch.stack(w2_marlin_list, dim=0), requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
enable_eplb: bool = False,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
if enable_eplb:
raise NotImplementedError(
"EPLB not supported for `SlimQuantW4A8Int8MarlinMoEMethod` yet.")
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
workspace, global_reduce_buffer = MarlinMoeWorkspace(x.device).get_buffers()
return fused_experts_impl_w4a8_marlin(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
workspace=workspace,
global_reduce_buffer=global_reduce_buffer,
inplace=True,
use_int4_w4a8=True,
per_channel_quant=True,
activation=activation,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
)
import torch
import numpy as np
def unpack_int8_to_int4(tensor_int8: torch.Tensor) -> torch.Tensor:
"""
将[N, K//2]大小的torch.int8 Tensor,转换为[N, K]大小的torch.int32 Tensor。
每个int8包含两个int4,分别提取到int32的低4位,其余位为0。
Args:
tensor_int8 (torch.Tensor): 输入张量,形状为[N, K//2],类型为torch.int8。
Returns:
torch.Tensor: 输出张量,形状为[N, K],类型为torch.int32。
"""
if tensor_int8.dtype != torch.int8:
raise ValueError("Input tensor must be of type torch.int8")
N, K_half = tensor_int8.shape
tensor_uint8 = tensor_int8.to(torch.uint8)
high4 = tensor_uint8 & 0x0F
low4 = (tensor_uint8 >> 4) & 0x0F
unpacked = torch.empty((N, K_half * 2), dtype=torch.int32, device=tensor_int8.device)
unpacked[:, 0::2] = low4.to(torch.int32)
unpacked[:, 1::2] = high4.to(torch.int32)
return unpacked
def get_weight_perms(interleave: bool=True):
perm = []
for i in range(64):
for col in range(4):
cur_col = (i % 16) * 4 + col
for row in range(8):
cur_row = (i // 16) * 8 + row
cur_idx = cur_row * 64 + cur_col
perm.append(cur_idx)
perm = np.array(perm)
if interleave:
interleave = np.array([4, 0, 5, 1, 6, 2, 7, 3])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
return perm
def marlin_weights(q_w,weight_perm,k_tile=32,n_tile=64,pack_factor=8):
size_k, size_n = q_w.shape
q_w = q_w.reshape((size_k // k_tile, k_tile, size_n // n_tile, n_tile))
q_w = q_w.permute((0, 2, 1, 3))
q_w = q_w.reshape((size_k // k_tile, size_n * k_tile))
q_w = q_w.reshape((-1, weight_perm.numel()))[:, weight_perm].reshape(q_w.shape)
orig_device = q_w.device
q_w = q_w.cpu().numpy().astype(np.uint32)
q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), dtype=np.uint32)
for i in range(pack_factor):
q_packed |= q_w[:, i::pack_factor] << 4 * i
q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device)
return q_packed
def w4a8_2_marlin_weight(w4a8_w):
full_w4a8_w = unpack_int8_to_int4(w4a8_w)
full_w4a8_w = full_w4a8_w.T
weight_perm = get_weight_perms()
marlin_q_w = marlin_weights(full_w4a8_w, weight_perm, k_tile=32, n_tile=64, pack_factor=8)
return marlin_q_w
...@@ -67,7 +67,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter, ...@@ -67,7 +67,7 @@ from .utils import (PPMissingLayer, is_pp_missing_parameter,
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '1') os.environ['DPSK_FP16_QUICK'] = os.environ.get('DPSK_FP16_QUICK', '0')
class DeepseekV2MLP(nn.Module): class DeepseekV2MLP(nn.Module):
def __init__( def __init__(
...@@ -622,9 +622,13 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -622,9 +622,13 @@ class DeepseekV2DecoderLayer(nn.Module):
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
# Fix residual FP16 overflow
residual_fix_overflow = False
if residual is None: if residual is None:
residual = hidden_states residual = hidden_states
hidden_states = self.input_layernorm(hidden_states) hidden_states = self.input_layernorm(hidden_states)
residual_fix_overflow = True
else: else:
hidden_states, residual = self.input_layernorm( hidden_states, residual = self.input_layernorm(
hidden_states, residual) hidden_states, residual)
...@@ -640,7 +644,7 @@ class DeepseekV2DecoderLayer(nn.Module): ...@@ -640,7 +644,7 @@ class DeepseekV2DecoderLayer(nn.Module):
# We scale both hidden_states and residual before # We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale. # rmsnorm, and rmsnorm result would not affect by scale.
hidden_states *= 1. / self.routed_scaling_factor hidden_states *= 1. / self.routed_scaling_factor
if self.layer_idx == 0: if self.layer_idx == 0 or residual_fix_overflow:
# The residual is shared by all layers, we only scale it on # The residual is shared by all layers, we only scale it on
# first layer. # first layer.
residual *= 1. / self.routed_scaling_factor residual *= 1. / self.routed_scaling_factor
...@@ -778,14 +782,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts): ...@@ -778,14 +782,17 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts):
self.num_expert_groups = config.n_group self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = [] self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers: for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer) assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE): if isinstance(layer.mlp, DeepseekV2MoE):
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts) self.moe_layers.append(layer.mlp.experts)
# Pick last one layer since the first ones may be dense layers. # Pick last one layer since the first ones may be dense layers.
example_moe = typing.cast(
DeepseekV2MoE, self.model.layers[config.num_hidden_layers - 1].mlp)
self.num_logical_experts = example_moe.n_logical_experts self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts self.num_local_physical_experts = example_moe.n_local_physical_experts
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import json
import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict, Optional
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
# 移除全局的 current_platform 导入,改为在需要时局部导入
# FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
num_ldmatrixes: Optional[int]
def benchmark_config(
config: BenchmarkConfig,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False
) -> float:
from vllm.platforms import current_platform
device = torch.cuda.current_device()
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
if use_int8_w8a16:
if not nn_moe:
w1 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8,
device=device,
)
else:
w1 = torch.randint(
-127,
127,
(
num_experts,
hidden_size,
shard_intermediate_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
127,
(
num_experts,
shard_intermediate_size // 2,
hidden_size,
),
dtype=torch.int8,
device=device,
)
else:
if not nn_moe:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device
)
else:
w1 = torch.randn(
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype, device=device
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device)
w1_scale = None
w2_scale = None
a1_scale = None
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device)
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
E = num_experts
N = shard_intermediate_size // 2
K = hidden_size
factor_for_scale = 1e-2
n_tiles_w1 = (2 * N + block_n - 1) // block_n
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
* factor_for_scale
)
w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
* factor_for_scale
)
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
a1_scale = torch.randn(1, dtype=torch.float32, device=device)
a2_scale = torch.randn(1, dtype=torch.float32, device=device)
# 获取 FP8_DTYPE
FP8_DTYPE = current_platform.fp8_dtype()
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False
)
return fused_experts(
x,
w1,
w2,
topk_weights,
topk_ids,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
allow_deep_gemm=True,
use_nn_moe=nn_moe,
)
else:
fused_moe(
x,
w1,
w2,
input_gating,
topk,
renormalize=True,
inplace=True,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_quant_shape,
use_nn_moe=nn_moe,
)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
# run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
# run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [32, 64, 128, 256]
if not use_fp16:
block_k_range.remove(16) # BLOCK_K=16 not supported for fp8
num_warps_range = [2, 4, 8]
group_m_range = [1, 16, 32, 64]
num_stage_range = [2, 3, 4, 5]
# waves_per_eu_range = [0]
# matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
# kpack_range = [1, 2] if use_fp16 else []
param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
# "waves_per_eu": waves_per_eu_range,
}
if nn_moe:
param_ranges["num_ldmatrixes"] = [1]
# DCU currently does not support the following parameters
# if use_fp16:
# param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
# param_ranges["kpack"] = kpack_range
return param_ranges
def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []
# 局部导入 current_platform
from vllm.platforms import current_platform
if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
else:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
block_m_range = [16, 32, 64, 128, 256]
block_n_range = [32, 64, 128, 256]
block_k_range = [64, 128, 256]
num_warps_range = [4, 8]
group_m_range = [1, 16, 32, 64]
num_stage_range = [2, 3, 4, 5]
param_ranges = {
"BLOCK_SIZE_M": block_m_range,
"BLOCK_SIZE_N": block_n_range,
"BLOCK_SIZE_K": block_k_range,
"GROUP_SIZE_M": group_m_range,
"num_warps": num_warps_range,
"num_stages": num_stage_range,
}
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
configs.append(config)
# Remove configs that are not compatible with fp8 block quantization
# BLOCK_SIZE_K must be a multiple of block_k
# BLOCK_SIZE_N must be a multiple of block_n
if block_quant_shape is not None and not use_fp16:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
for config in configs[:]:
if (
config["BLOCK_SIZE_K"] % block_k != 0
or config["BLOCK_SIZE_N"] % block_n != 0
):
configs.remove(config)
return configs
def prune_rocm_search_space(
num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk
):
N1, K1 = shard_intermediate_size, hidden_size
N2, K2 = hidden_size, shard_intermediate_size // 2
pruned_space_1 = prune_rocm_configs(
num_tokens * topk, N1, K1, search_space, is_fp16
)
pruned_space_2 = prune_rocm_configs(
num_tokens * topk, N2, K2, search_space, is_fp16
)
search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
return search_space
# The following code is inspired by ROCm/Triton GEMM tuning script:
# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
def prune_rocm_configs(M, N, K, configs, is_fp16=True):
pruned_configs = []
elemBytes_a = 2 if is_fp16 else 1
elemBytes_b = 2 if is_fp16 else 1
mfma = 16 if M < 32 or N < 32 else 32
# TODO (zhanglx): figure out the boundary between large and small gemms
large_gemm = False
if M >= 2048 and N >= 2048:
large_gemm = True
for config in configs:
BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
num_warps = config.get("num_warps")
# DCU currently does not support matrix_instr_nonkdim param
# if is_fp16:
# matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
# if matrix_instr_nonkdim > mfma:
# continue
if mfma == 4 and BLOCK_SIZE_K < 64:
continue
# some layouts could not work properly in case
# number elements per thread is less 1
if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
continue
SPLIT_K = config.get("SPLIT_K", 1)
GROUP_M = config.get("GROUP_SIZE_M")
# DCU currently does not support matrix_instr_nonkdim param
# if is_fp16:
# if (
# matrix_instr_nonkdim > BLOCK_SIZE_M
# or matrix_instr_nonkdim > BLOCK_SIZE_N
# ):
# continue
# if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M:
# continue
# if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N:
# continue
# Skip BLOCK_SIZE that is too large compare to M/N
# unless BLOCK_SIZE is already small enough
if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
continue
if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
continue
# skip large split_k when not necessary
if SPLIT_K != 1 and not need_split_k(M, N, K):
continue
# skip split_k that leads to EVEN_K = false
leap = SPLIT_K * BLOCK_SIZE_K
modv = K % leap
if modv != 0:
continue
# skip large GROUP_M
if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
continue
# out of shared memory resource
# TODO (zhanglx): This does not consider the LDS usage in the epilogue
LDS = (
BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a
+ BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b
)
if LDS > 65536:
continue
# Skip small block sizes and num_warps for large gemm
# For fp16 and f8, we want to only use BLOCK_SIZE >= 64
if large_gemm:
if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
continue
if BLOCK_SIZE_K < 64:
continue
if num_warps < 4:
continue
pruned_configs.append(config)
return pruned_configs
def need_split_k(SIZE_M, SIZE_N, SIZE_K):
return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
def merge_unique_dicts(list1, list2):
result = []
combined_list = list1.copy()
combined_list.extend(list2)
for dictionary in combined_list:
if dictionary not in result:
result.append(dictionary)
return result
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int, device_id: int) -> None:
from vllm.platforms import current_platform
import os
if current_platform.is_rocm():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else:
torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed)
self.seed = seed
# Store the logical device ID for Ray
self.device_id = device_id
def benchmark(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: list[int] = None,
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False,
) -> tuple[dict[str, int], float]:
# 局部导入 current_platform
from vllm.platforms import current_platform
current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_moe_configs, get_default_config
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(
num_experts, shard_intermediate_size // 2, dtype_str, use_nn_moe=nn_moe
)
if op_config is None:
config = get_default_config(
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype_str,
is_marlin=False,
use_nn_moe=nn_moe
)
else:
config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
use_nn_moe=nn_moe
)
return config, kernel_time
def tune(
self,
num_tokens: int,
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: list[dict[str, int]],
block_quant_shape: list[int],
use_deep_gemm: bool,
nn_moe: Optional[bool] = False,
) -> dict[str, int]:
from vllm.platforms import current_platform
import os
best_config = None
best_time = float("inf")
if current_platform.is_rocm():
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = prune_rocm_search_space(
num_tokens,
shard_intermediate_size,
hidden_size,
search_space,
is_fp16,
topk,
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard = False
if current_platform.is_rocm():
# For ROCm with Ray, skip additional device context management
need_device_guard = False
else:
# For other platforms, use device guard if needed
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if visible_devices is not None and len(visible_devices.split(',')) > 1:
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
config,
num_tokens,
num_experts,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20,
block_quant_shape=block_quant_shape,
use_deep_gemm=use_deep_gemm,
nn_moe=nn_moe)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
if kernel_time < best_time:
best_time = kernel_time
best_config = config
now = datetime.now()
print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
assert best_config is not None
return best_config
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
"GROUP_SIZE_M": config["GROUP_SIZE_M"],
"num_warps": config["num_warps"],
"num_stages": config["num_stages"],
**(
{"num_ldmatrixes": config["num_ldmatrixes"]} if "num_ldmatrixes" in config else {}
),
**(
{"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}
),
**(
{"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
if "matrix_instr_nonkdim" in config
else {}
),
**({"kpack": config["kpack"]} if "kpack" in config else {}),
}
def save_configs(
configs: dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
block_quant_shape: list[int],
use_nn_moe: Optional[bool] = False,
) -> None:
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_config_file_name
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(
num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape, use_nn_moe=use_nn_moe
)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
f.write("\n")
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, "quantization_config", {})
if isinstance(quantization_config, dict):
return quantization_config.get("weight_block_size", default_value)
return default_value
def main(args: argparse.Namespace):
import os
import logging
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
print(args)
tp_size = args.tp_size
config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
if args.model_prefix:
config = getattr(config, args.model_prefix)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", "Glm4MoeForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"):
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ("Step3VLForConditionalGeneration"):
E = config.text_config.moe_num_experts
topk = config.text_config.moe_top_k
intermediate_size = config.text_config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
else:
# Support for llama4
config = config.get_text_config()
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
block_quant_shape = get_weight_block_size_safety(config)
if args.batch_size is None:
batch_sizes = [
1,
2,
4,
8,
16,
24,
32,
48,
64,
96,
128,
256,
512,
1024,
1536,
2048,
3072,
4096,
]
else:
batch_sizes = args.batch_size
use_deep_gemm = bool(args.use_deep_gemm)
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES."
)
val = os.environ["HIP_VISIBLE_DEVICES"]
os.environ["ROCR_VISIBLE_DEVICES"] = val
del os.environ["HIP_VISIBLE_DEVICES"]
ray.init(address=None, ignore_reinit_error=True, num_gpus=args.num_gpus)
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed, i) for i in range(num_gpus)]
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16, block_quant_shape, args.nn_moe)
print(f"Start tuning over {len(search_space)} configurations...")
start = time.time()
configs = _distribute(
"tune",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
search_space,
block_quant_shape,
use_deep_gemm,
args.nn_moe,
)
for batch_size in batch_sizes
],
)
best_configs = {
M: sort_config(config) for M, config in zip(batch_sizes, configs)
}
save_configs(
best_configs,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
use_nn_moe=args.nn_moe,
)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute(
"benchmark",
[
(
batch_size,
E,
shard_intermediate_size,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
block_quant_shape,
use_deep_gemm,
args.nn_moe,
)
for batch_size in batch_sizes
],
)
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}, config: {config}")
print(f"Kernel time: {kernel_time:.2f} us")
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument(
"--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1"
)
parser.add_argument(
"--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2
)
parser.add_argument(
"--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto"
)
parser.add_argument("--use-deep-gemm", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn-moe", action='store_true', default=False)
parser.add_argument("--trust-remote-code", action="store_true")
parser.add_argument("--model-prefix", type=str, required=False)
parser.add_argument("--num-gpus", type=int, default=1)
args = parser.parse_args()
main(args)
...@@ -180,7 +180,7 @@ class RocmPlatform(Platform): ...@@ -180,7 +180,7 @@ class RocmPlatform(Platform):
supported_quantization: list[str] = [ supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf", "awq", "gptq", "fp8", "compressed-tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin" "quark", "ptpc_fp8", "moe_wna16", "blockwise_int8","slimquant_w4a8","awq_marlin","slimquant_w4a8_marlin"
] ]
@classmethod @classmethod
...@@ -282,6 +282,10 @@ class RocmPlatform(Platform): ...@@ -282,6 +282,10 @@ class RocmPlatform(Platform):
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN_VLLM_V1
if selected_backend == _Backend.DUAL_CHUNK_FLASH_ATTN:
logger.info("Using DualChunkFlashAttention backend.")
return ("vllm.attention.backends.dual_chunk_flash_attn."
"DualChunkFlashAttentionBackend")
if selected_backend == _Backend.ROCM_FLASH: if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90): if not cls.has_device_capability(90):
# not Instinct series GPUs. # not Instinct series GPUs.
......
...@@ -177,11 +177,11 @@ def zero_overhead_update_from_output(scheduler:Scheduler, ...@@ -177,11 +177,11 @@ def zero_overhead_update_from_output(scheduler:Scheduler,
# loop can be a performance bottleneck. We should do our best to avoid # loop can be a performance bottleneck. We should do our best to avoid
# expensive operations inside the loop. # expensive operations inside the loop.
for request in scheduler.running: for request in scheduler.running:
req_id = request.request_id
if request.is_finished(): if request.is_finished():
if req_id in requsets_valid_token_len: if req_id in requsets_valid_token_len:
requsets_valid_token_len.pop(req_id) requsets_valid_token_len.pop(req_id)
continue continue
req_id = request.request_id
num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0) num_tokens_scheduled = num_scheduled_tokens.get(req_id, 0)
if num_tokens_scheduled == 0: if num_tokens_scheduled == 0:
# The request was not scheduled in this step. # The request was not scheduled in this step.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment