Commit 83f2f396 authored by 王敏's avatar 王敏
Browse files

同步0.9.2-ds分支代码

parents d2e57a90 20605c42
......@@ -282,6 +282,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
......@@ -324,6 +325,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl(
......
import torch
import numpy as np
try:
from lightop import awq_marlin_repack_w4a8
use_lightop = True
use_lightop = False
except Exception:
use_lightop = False
......@@ -90,4 +89,4 @@ def w4a8_weight_repack_impl(input):
w_marlin_list.append(w_marlin_in)
output = torch.stack(w_marlin_list, dim=0)
return output
\ No newline at end of file
return output
......@@ -392,20 +392,26 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
......@@ -37,6 +37,8 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
......@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1)
return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=[],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward(
self,
positions: torch.Tensor,
......@@ -938,8 +958,12 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE=BLOCK_SIZE,
num_warps=1)
call(query)
call(key)
# if envs.VLLM_USE_LIGHTOP:
if False:
torch.ops.vllm.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query)
call(key)
return query, key
else:
query_rot = query[..., :self.rotary_dim]
......
......@@ -238,14 +238,28 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0'
else:
os.environ['LLAMA_NN'] = '1'
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0'
else:
os.environ['LM_NN'] = '1'
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# awq相关配置
try:
if os.getenv('AWQ_MOE_SZ') == None:
......
......@@ -274,6 +274,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1
else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1
......
......@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
# "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
"int8": torch.int8,
}
......
......@@ -216,7 +216,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -894,6 +893,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
):
assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill
......@@ -913,6 +913,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
)
kv_c_normed = workspace[:toks]\
......@@ -925,8 +927,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
......@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor:
assert attn_metadata.prefill is not None
......@@ -989,8 +993,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT:
if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2)
else:
......@@ -1015,7 +1020,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output)
merge_attn_states(
......@@ -1104,7 +1109,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
attn_metadata, kv_scale=layer._k_scale)
if has_decode:
assert attn_metadata.decode is not None
......
......@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable
from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__)
......@@ -150,9 +150,10 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8":
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError(
"FlashMLA with other KV cache not yet supported")
def _forward_decode(
self,
......@@ -166,8 +167,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT:
if q_nope.shape[0] <= 1024:
if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1)
else:
......
......@@ -5,7 +5,10 @@ from functools import reduce
import pytest
import torch
import math
from lightop import ds_cat
import vllm.envs as envs
if envs.VLLM_USE_OPT_CAT:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim):
......
......@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] -
len(spec_decode_tokens.get(req_id, ())))
num_tokens = req.num_generated_token_ids
if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner
# will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req.
num_computed_tokens + num_tokens]
token_ids = req.all_token_ids[-num_tokens:]
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
......@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled
......@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected
# tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids),
......
......@@ -79,6 +79,7 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
# Multi-modal related
......
......@@ -499,8 +499,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
elif num_new_tokens > 0:
req_state.output_token_ids.extend(
new_token_ids[-num_new_tokens:])
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
if len(spec_token_ids) > 0:
req_state.spec_token_ids = spec_token_ids
# Update the block IDs.
if not resumed_from_preemption:
......@@ -531,10 +531,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not is_last_rank:
# Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids)
end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[
req_index,
start_token_index:end_token_index] = new_token_ids
start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[
req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment