Commit ad1d74cf authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-for-yuanbao' of...

Merge branch 'v0.9.2-dev-for-yuanbao' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev-for-yuanbao
parents 2b7b1a31 d2feb104
......@@ -208,6 +208,8 @@ if TYPE_CHECKING:
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False
def get_default_cache_root():
return os.getenv(
......@@ -1344,6 +1346,44 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool(
int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0"))
),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE":
lambda: (os.getenv("VLLM_MOE_ROUTER_CAPTURE", "0").lower() in ("true", "1")),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR":
lambda: os.environ.get(
"VLLM_MOE_ROUTER_CAPTURE_DIR",
"/tmp",
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_RANK", "-1")),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS", "0")),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT", "-1")),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
# Only quantized DeepSeek models supported.
"VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -217,7 +217,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous
from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -871,13 +873,73 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse
return attn_out
def weight_quant_fp8(self, weight, dim:Optional[int]=1):
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
absmax = torch.max(weight.abs(), dim=dim, keepdim=True).values
absmax = absmax.clamp(min=1e-10)
scale = absmax.to(torch.float32) / fp8_max
scale = scale.clamp(min=1e-10)
weight_fp32 = weight.float() if weight.dtype != torch.float32 else weight
scale_fp32 = scale.float() if scale.dtype != torch.float32 else scale
weight_q = weight_fp32 / scale_fp32
weight_q = weight_q.clamp(fp8_min, fp8_max)
weight_q = weight_q.to(torch.float8_e4m3fn)
return weight_q, scale
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
if self.enable_fused_DTBmm():
x = x.view(-1, self.num_heads, self.kv_lora_rank).contiguous()
B, N, L = x.shape
N, V, L = self.weight_uv_bmm.shape
if B <= 32:
from lightop import fused_bmm as fused_DTBmm
from lightop import get_batched_gemm_w8a8_config as DTBmm_config
x = x.reshape(-1, self.num_heads, self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x)
x_out = torch.empty(B, N, V, dtype=torch.bfloat16, device=x.device)
_dtype = torch.bfloat16
_config, _status = DTBmm_config(B, N, L)
assert x_q.shape == (B, N, L) , f"assert error {x_q.shape}"
assert x_scale.shape == (B, N, 1) , f"assert error {x_scale.shape}"
fused_DTBmm(x=x_q, w=self.weight_uv_bmm, x_scale=x_scale, w_scale=self.weight_uv_scale_bmm,
bias=None, dtype=_dtype, output=x_out,
transpose_bm=False, transpose_bm_in=False, config=_config)
out = x_out.reshape(-1, self.num_heads * self.v_head_dim)
return out
else:
from lmslim import quant_ops
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
x_= x.reshape(-1,self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x_)
x_q = x_q.reshape(self.num_heads,-1,self.kv_lora_rank).contiguous()
x_scale = x_scale.reshape(self.num_heads,-1).contiguous()
weight_k = self.W_UV.shape[1]
weight_n = self.W_UV.shape[2]
_, result = quant_ops.hipblaslt_w8a8_channelwise_gemm(
x_q, self.weight_uv_bmm , x_scale, self.weight_uv_scale_bmm,
x.shape[1], weight_n, weight_k, 'NT', torch.bfloat16, None)
return result.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: # default
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def enable_fused_DTBmm(self):
if envs.VLLM_USE_FUSED_DTBMM and \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
else:
return False
def process_weights_after_loading(self, act_dtype: torch.dtype):
......@@ -932,6 +994,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0)
if self.enable_fused_DTBmm():
weight_uv_NLV, weight_uv_scale_NL =self.weight_quant_fp8(self.W_UV, 1)
self.weight_uv_bmm = weight_uv_NLV.transpose(1,2).contiguous()
self.weight_uv_scale_bmm = weight_uv_scale_NL.transpose(1,2).contiguous()
def _compute_prefill_context(
self,
......
......@@ -103,7 +103,8 @@ class Sampler(nn.Module):
if sampling_metadata.all_random:
greedy_sampled = None
else:
greedy_sampled = self.greedy_sample(logits)
#greedy_sampled = self.greedy_sample(logits)
greedy_sampled = logits.argmax(dim=-1).view(-1)
if sampling_metadata.all_greedy:
return greedy_sampled
......
......@@ -321,7 +321,7 @@ def bind_kv_cache(
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor:
length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Copy the first length elements of a tensor into another tensor in a
non-blocking manner.
......@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor.
"""
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
......@@ -9,6 +9,7 @@ import numpy as np
import torch
from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams
......@@ -79,6 +80,10 @@ class InputBatch:
is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False,
):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
......@@ -97,7 +102,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not
# need to be pinned.
self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len),
(ori_max_num_reqs, max_model_len),
device="cpu",
dtype=torch.int32,
pin_memory=False,
......@@ -651,36 +656,44 @@ class InputBatch:
or repeat_counts is not None
or self._sampling_metadata_is_expanded)
if needs_rebuild:
if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata()
else:
self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts)
# if repeat_counts is None:
# self.sampling_metadata = self._make_sampling_metadata()
# else:
# self.sampling_metadata = self._make_sampling_metadata_expanded(
# repeat_counts)
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata:
def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs
if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs)
self.temperature, num_reqs,
repeat_counts)
else:
temperature = None
if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs)
top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs)
top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
if not self.no_penalties:
# Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require
# penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs)
copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs)
copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs)
frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs,
repeat_counts)
presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
self.presence_penalties, num_reqs,
repeat_counts)
repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs,
repeat_counts)
needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0
......@@ -697,9 +710,9 @@ class InputBatch:
allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs)
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs]
allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs,
repeat_counts)
# Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling).
......@@ -714,14 +727,14 @@ class InputBatch:
temperature=temperature,
all_greedy=self.all_greedy,
all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs],
top_k=None if self.no_top_k else self.top_k[:num_reqs],
top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else top_k,
generators=self.generators,
max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs],
presence_penalties=self.presence_penalties[:num_reqs],
repetition_penalties=self.repetition_penalties[:num_reqs],
frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask,
......
......@@ -587,17 +587,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None
repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
repeat_counts = [1] * self.input_batch.num_reqs
#num_reqs = self.input_batch.num_reqs
#num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1)
repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange(
......@@ -3437,8 +3439,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
)
sampler_output.sampled_token_ids = output_token_ids
else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......
......@@ -644,8 +644,8 @@ class V1ZeroModelRunner(GPUModelRunner):
)
sampler_output.sampled_token_ids = output_token_ids
else:
# sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False
sampling_metadata.all_greedy = True
sampling_metadata.all_random = False
sampler_output = self.sampler(
logits=logits,
sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment