Commit ad1d74cf authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-for-yuanbao' of...

Merge branch 'v0.9.2-dev-for-yuanbao' of http://10.16.6.30/dcutoolkit/deeplearing/vllm into v0.9.2-dev-for-yuanbao
parents 2b7b1a31 d2feb104
...@@ -208,6 +208,8 @@ if TYPE_CHECKING: ...@@ -208,6 +208,8 @@ if TYPE_CHECKING:
VLLM_V1_FAST_TOKEN_ID_COPY: bool = False VLLM_V1_FAST_TOKEN_ID_COPY: bool = False
VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True VLLM_DISABLE_SHARED_EXPERTS_STREAM:bool = True
VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False VLLM_ENABLE_SHARED_EXPERTS_FUSION: bool = False
VLLM_USE_MOE_W16A16_TRITON: bool = False
VLLM_USE_FUSED_DTBMM: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1344,6 +1346,44 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1344,6 +1346,44 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool( "VLLM_ENABLE_SHARED_EXPERTS_FUSION": lambda: bool(
int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0")) int(os.getenv("VLLM_ENABLE_SHARED_EXPERTS_FUSION", "0"))
), ),
# W8A8 GEMM backend selection for vLLM quantized models.
# lightop/triton: 1
# cutlass: 2 (will remove in the future)
# blaslt: 3 (default)
# rocblas: others
"VLLM_W8A8_BACKEND": lambda: int(os.getenv("VLLM_W8A8_BACKEND", "3")),
# Capture MoE router logits for debugging/analysis.
"VLLM_MOE_ROUTER_CAPTURE":
lambda: (os.getenv("VLLM_MOE_ROUTER_CAPTURE", "0").lower() in ("true", "1")),
# Output directory for MoE router capture dumps.
"VLLM_MOE_ROUTER_CAPTURE_DIR":
lambda: os.environ.get(
"VLLM_MOE_ROUTER_CAPTURE_DIR",
"/tmp",
),
# Capture only the specified rank; set to -1 to capture all ranks.
"VLLM_MOE_ROUTER_CAPTURE_RANK":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_RANK", "-1")),
# Max number of MoE layers to record per process (0 = unlimited).
"VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_MAX_LAYERS", "0")),
# Only capture when num_tokens > N (negative disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_GT", "-1")),
# Only capture when num_tokens < N (0 disables).
"VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT":
lambda: int(os.environ.get("VLLM_MOE_ROUTER_CAPTURE_NUM_TOKENS_LT", "-1")),
# Force using Triton MoE path (disable Marlin W16A16 MoE).
"VLLM_USE_MOE_W16A16_TRITON":
lambda: (os.environ.get("VLLM_USE_MOE_W16A16_TRITON", "0").lower() in
("true", "1")),
# Only quantized DeepSeek models supported.
"VLLM_USE_FUSED_DTBMM":
lambda: (os.environ.get("VLLM_USE_FUSED_DTBMM", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -217,7 +217,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -217,7 +217,9 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from lightop import fused_rms_norm_rope_contiguous from lightop import fused_rms_norm_rope_contiguous, fuse_rmsnorm_rope_quant_qkv
from lmslim.layers.gemm.fp8_utils import per_token_quant_fp8
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -871,13 +873,73 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -871,13 +873,73 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return attn_out, lse return attn_out, lse
return attn_out return attn_out
def weight_quant_fp8(self, weight, dim:Optional[int]=1):
finfo = torch.finfo(torch.float8_e4m3fn)
fp8_min = finfo.min
fp8_max = finfo.max
absmax = torch.max(weight.abs(), dim=dim, keepdim=True).values
absmax = absmax.clamp(min=1e-10)
scale = absmax.to(torch.float32) / fp8_max
scale = scale.clamp(min=1e-10)
weight_fp32 = weight.float() if weight.dtype != torch.float32 else weight
scale_fp32 = scale.float() if scale.dtype != torch.float32 else scale
weight_q = weight_fp32 / scale_fp32
weight_q = weight_q.clamp(fp8_min, fp8_max)
weight_q = weight_q.to(torch.float8_e4m3fn)
return weight_q, scale
def _v_up_proj(self, x): def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L) if self.enable_fused_DTBmm():
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).contiguous()
# Multiply (N, B, L) x (N, L, V) -> (N, B, V) B, N, L = x.shape
x = torch.bmm(x, self.W_UV) N, V, L = self.weight_uv_bmm.shape
# Convert from (N, B, V) to (B, N * V) if B <= 32:
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) from lightop import fused_bmm as fused_DTBmm
from lightop import get_batched_gemm_w8a8_config as DTBmm_config
x = x.reshape(-1, self.num_heads, self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x)
x_out = torch.empty(B, N, V, dtype=torch.bfloat16, device=x.device)
_dtype = torch.bfloat16
_config, _status = DTBmm_config(B, N, L)
assert x_q.shape == (B, N, L) , f"assert error {x_q.shape}"
assert x_scale.shape == (B, N, 1) , f"assert error {x_scale.shape}"
fused_DTBmm(x=x_q, w=self.weight_uv_bmm, x_scale=x_scale, w_scale=self.weight_uv_scale_bmm,
bias=None, dtype=_dtype, output=x_out,
transpose_bm=False, transpose_bm_in=False, config=_config)
out = x_out.reshape(-1, self.num_heads * self.v_head_dim)
return out
else:
from lmslim import quant_ops
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
x_= x.reshape(-1,self.kv_lora_rank).contiguous()
x_q, x_scale = per_token_quant_fp8(x_)
x_q = x_q.reshape(self.num_heads,-1,self.kv_lora_rank).contiguous()
x_scale = x_scale.reshape(self.num_heads,-1).contiguous()
weight_k = self.W_UV.shape[1]
weight_n = self.W_UV.shape[2]
_, result = quant_ops.hipblaslt_w8a8_channelwise_gemm(
x_q, self.weight_uv_bmm , x_scale, self.weight_uv_scale_bmm,
x.shape[1], weight_n, weight_k, 'NT', torch.bfloat16, None)
return result.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: # default
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
def enable_fused_DTBmm(self):
if envs.VLLM_USE_FUSED_DTBMM and \
torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938":
return True
else:
return False
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
...@@ -932,6 +994,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -932,6 +994,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.W_UV = W_UV.transpose(0, 1) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L) # Convert from (L, N, P) to (N, P, L)
self.W_UK_T = W_UK.permute(1, 2, 0) self.W_UK_T = W_UK.permute(1, 2, 0)
if self.enable_fused_DTBmm():
weight_uv_NLV, weight_uv_scale_NL =self.weight_quant_fp8(self.W_UV, 1)
self.weight_uv_bmm = weight_uv_NLV.transpose(1,2).contiguous()
self.weight_uv_scale_bmm = weight_uv_scale_NL.transpose(1,2).contiguous()
def _compute_prefill_context( def _compute_prefill_context(
self, self,
......
...@@ -103,7 +103,8 @@ class Sampler(nn.Module): ...@@ -103,7 +103,8 @@ class Sampler(nn.Module):
if sampling_metadata.all_random: if sampling_metadata.all_random:
greedy_sampled = None greedy_sampled = None
else: else:
greedy_sampled = self.greedy_sample(logits) #greedy_sampled = self.greedy_sample(logits)
greedy_sampled = logits.argmax(dim=-1).view(-1)
if sampling_metadata.all_greedy: if sampling_metadata.all_greedy:
return greedy_sampled return greedy_sampled
......
...@@ -321,7 +321,7 @@ def bind_kv_cache( ...@@ -321,7 +321,7 @@ def bind_kv_cache(
def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
length: int) -> torch.Tensor: length: int, repeat_counts: Optional[torch.Tensor] = None) -> torch.Tensor:
""" """
Copy the first length elements of a tensor into another tensor in a Copy the first length elements of a tensor into another tensor in a
non-blocking manner. non-blocking manner.
...@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor, ...@@ -330,6 +330,11 @@ def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
Returns the sliced target tensor. Returns the sliced target tensor.
""" """
if repeat_counts is not None:
from_tensor_tmp = torch.repeat_interleave(from_tensor[:length], repeat_counts, dim=0)
length = torch.sum(repeat_counts).item()
from_tensor[:length].copy_(from_tensor_tmp)
return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True) return to_tensor[:length].copy_(from_tensor[:length], non_blocking=True)
......
...@@ -9,6 +9,7 @@ import numpy as np ...@@ -9,6 +9,7 @@ import numpy as np
import torch import torch
from vllm import envs from vllm import envs
from vllm.config import get_current_vllm_config
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.pooling_params import PoolingParams from vllm.pooling_params import PoolingParams
...@@ -79,6 +80,10 @@ class InputBatch: ...@@ -79,6 +80,10 @@ class InputBatch:
is_spec_decode: bool = False, is_spec_decode: bool = False,
logits_processing_needs_token_ids: bool = False, logits_processing_needs_token_ids: bool = False,
): ):
ori_max_num_reqs = max_num_reqs
if is_spec_decode and envs.VLLM_REJECT_SAMPLE_OPT:
vllm_config = get_current_vllm_config()
max_num_reqs = max_num_reqs * (1 + vllm_config.speculative_config.num_speculative_tokens)
self.is_spec_decode = is_spec_decode self.is_spec_decode = is_spec_decode
self.max_num_reqs = max_num_reqs self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len self.max_model_len = max_model_len
...@@ -97,7 +102,7 @@ class InputBatch: ...@@ -97,7 +102,7 @@ class InputBatch:
# This buffer is not directly transferred to the GPU, so it does not # This buffer is not directly transferred to the GPU, so it does not
# need to be pinned. # need to be pinned.
self.token_ids_cpu_tensor = torch.zeros( self.token_ids_cpu_tensor = torch.zeros(
(max_num_reqs, max_model_len), (ori_max_num_reqs, max_model_len),
device="cpu", device="cpu",
dtype=torch.int32, dtype=torch.int32,
pin_memory=False, pin_memory=False,
...@@ -651,36 +656,44 @@ class InputBatch: ...@@ -651,36 +656,44 @@ class InputBatch:
or repeat_counts is not None or repeat_counts is not None
or self._sampling_metadata_is_expanded) or self._sampling_metadata_is_expanded)
if needs_rebuild: if needs_rebuild:
if repeat_counts is None: # if repeat_counts is None:
self.sampling_metadata = self._make_sampling_metadata() # self.sampling_metadata = self._make_sampling_metadata()
else: # else:
self.sampling_metadata = self._make_sampling_metadata_expanded( # self.sampling_metadata = self._make_sampling_metadata_expanded(
repeat_counts) # repeat_counts)
self.sampling_metadata = self._make_sampling_metadata(repeat_counts)
self._sampling_metadata_is_expanded = repeat_counts is not None self._sampling_metadata_is_expanded = repeat_counts is not None
# Expanded metadata is built on demand; do not cache a copy here. # Expanded metadata is built on demand; do not cache a copy here.
def _make_sampling_metadata(self) -> SamplingMetadata: def _make_sampling_metadata(self, repeat_counts: Optional[torch.Tensor] = None) -> SamplingMetadata:
num_reqs = self.num_reqs num_reqs = self.num_reqs
if not self.all_greedy: if not self.all_greedy:
temperature = copy_slice(self.temperature_cpu_tensor, temperature = copy_slice(self.temperature_cpu_tensor,
self.temperature, num_reqs) self.temperature, num_reqs,
repeat_counts)
else: else:
temperature = None temperature = None
if not self.no_top_p: if not self.no_top_p:
copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs) top_p = copy_slice(self.top_p_cpu_tensor, self.top_p, num_reqs, repeat_counts)
if not self.no_top_k: if not self.no_top_k:
copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs) top_k = copy_slice(self.top_k_cpu_tensor, self.top_k, num_reqs, repeat_counts)
frequency_penalties = None
presence_penalties = None
repetition_penalties = None
if not self.no_penalties: if not self.no_penalties:
# Since syncing these tensors is expensive only copy them # Since syncing these tensors is expensive only copy them
# if necessary i.e. if there are requests which require # if necessary i.e. if there are requests which require
# penalties to be applied during sampling. # penalties to be applied during sampling.
copy_slice(self.frequency_penalties_cpu_tensor, frequency_penalties = copy_slice(self.frequency_penalties_cpu_tensor,
self.frequency_penalties, num_reqs) self.frequency_penalties, num_reqs,
copy_slice(self.presence_penalties_cpu_tensor, repeat_counts)
self.presence_penalties, num_reqs) presence_penalties = copy_slice(self.presence_penalties_cpu_tensor,
copy_slice(self.repetition_penalties_cpu_tensor, self.presence_penalties, num_reqs,
self.repetition_penalties, num_reqs) repeat_counts)
repetition_penalties = copy_slice(self.repetition_penalties_cpu_tensor,
self.repetition_penalties, num_reqs,
repeat_counts)
needs_prompt_token_ids = (not self.no_penalties or needs_prompt_token_ids = (not self.no_penalties or
(self.num_reqs > 0 (self.num_reqs > 0
...@@ -697,9 +710,9 @@ class InputBatch: ...@@ -697,9 +710,9 @@ class InputBatch:
allowed_token_ids_mask: Optional[torch.Tensor] = None allowed_token_ids_mask: Optional[torch.Tensor] = None
if not self.no_allowed_token_ids: if not self.no_allowed_token_ids:
assert self.allowed_token_ids_mask is not None assert self.allowed_token_ids_mask is not None
copy_slice(self.allowed_token_ids_mask_cpu_tensor, allowed_token_ids_mask = copy_slice(self.allowed_token_ids_mask_cpu_tensor,
self.allowed_token_ids_mask, num_reqs) self.allowed_token_ids_mask, num_reqs,
allowed_token_ids_mask = self.allowed_token_ids_mask[:num_reqs] repeat_counts)
# Host-side summaries to avoid device synchronization in sampling # Host-side summaries to avoid device synchronization in sampling
# fast paths (e.g. reduced top-k/top-p sampling). # fast paths (e.g. reduced top-k/top-p sampling).
...@@ -714,14 +727,14 @@ class InputBatch: ...@@ -714,14 +727,14 @@ class InputBatch:
temperature=temperature, temperature=temperature,
all_greedy=self.all_greedy, all_greedy=self.all_greedy,
all_random=self.all_random, all_random=self.all_random,
top_p=None if self.no_top_p else self.top_p[:num_reqs], top_p=None if self.no_top_p else top_p,
top_k=None if self.no_top_k else self.top_k[:num_reqs], top_k=None if self.no_top_k else top_k,
generators=self.generators, generators=self.generators,
max_num_logprobs=self.max_num_logprobs, max_num_logprobs=self.max_num_logprobs,
prompt_token_ids=prompt_token_ids, prompt_token_ids=prompt_token_ids,
frequency_penalties=self.frequency_penalties[:num_reqs], frequency_penalties=None if self.no_penalties else frequency_penalties,
presence_penalties=self.presence_penalties[:num_reqs], presence_penalties=None if self.no_penalties else presence_penalties,
repetition_penalties=self.repetition_penalties[:num_reqs], repetition_penalties=None if self.no_penalties else repetition_penalties,
output_token_ids=cast(list[list[int]], self.req_output_token_ids), output_token_ids=cast(list[list[int]], self.req_output_token_ids),
no_penalties=self.no_penalties, no_penalties=self.no_penalties,
allowed_token_ids_mask=allowed_token_ids_mask, allowed_token_ids_mask=allowed_token_ids_mask,
......
...@@ -587,17 +587,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin): ...@@ -587,17 +587,19 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Refresh batch metadata with any pending updates. If we are in spec # Refresh batch metadata with any pending updates. If we are in spec
# decode + reject mode, also expand sampling metadata to token shape # decode + reject mode, also expand sampling metadata to token shape
# using per-request repeat counts. # using per-request repeat counts.
repeat_counts: Optional[torch.Tensor] = None repeat_counts = None
if envs.VLLM_REJECT_SAMPLE_OPT and \ if envs.VLLM_REJECT_SAMPLE_OPT and \
scheduler_output.scheduled_spec_decode_tokens: scheduler_output.scheduled_spec_decode_tokens:
num_reqs = self.input_batch.num_reqs repeat_counts = [1] * self.input_batch.num_reqs
num_draft_tokens = np.zeros(num_reqs, dtype=np.int32) #num_reqs = self.input_batch.num_reqs
#num_draft_tokens = np.zeros(num_reqs, dtype=np.int32)
for req_id, draft_token_ids in ( for req_id, draft_token_ids in (
scheduler_output.scheduled_spec_decode_tokens.items()): scheduler_output.scheduled_spec_decode_tokens.items()):
req_idx = self.input_batch.req_id_to_index.get(req_id) req_idx = self.input_batch.req_id_to_index.get(req_id)
if req_idx is not None: if req_idx is not None:
num_draft_tokens[req_idx] = len(draft_token_ids) repeat_counts[req_idx] += len(draft_token_ids)
repeat_counts = torch.from_numpy(num_draft_tokens).add_(1) repeat_counts = torch.tensor(repeat_counts, dtype=torch.int32, device="cpu")
self.input_batch.refresh_metadata(repeat_counts) self.input_batch.refresh_metadata(repeat_counts)
def _get_cumsum_and_arange( def _get_cumsum_and_arange(
...@@ -3437,8 +3439,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase): ...@@ -3437,8 +3439,8 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
# sampling_metadata.all_greedy = True sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
......
...@@ -644,8 +644,8 @@ class V1ZeroModelRunner(GPUModelRunner): ...@@ -644,8 +644,8 @@ class V1ZeroModelRunner(GPUModelRunner):
) )
sampler_output.sampled_token_ids = output_token_ids sampler_output.sampled_token_ids = output_token_ids
else: else:
# sampling_metadata.all_greedy = True sampling_metadata.all_greedy = True
# sampling_metadata.all_random = False sampling_metadata.all_random = False
sampler_output = self.sampler( sampler_output = self.sampler(
logits=logits, logits=logits,
sampling_metadata=sampling_metadata, sampling_metadata=sampling_metadata,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment