Commit 83f2f396 authored by 王敏's avatar 王敏
Browse files

同步0.9.2-ds分支代码

parents d2e57a90 20605c42
...@@ -282,6 +282,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -282,6 +282,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
use_nn_moe: Optional[bool] = False, use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None, routed_scaling_factor: Optional[float] = None,
use_fused_gate: Optional[bool] = False, use_fused_gate: Optional[bool] = False,
shared_output: Optional[torch.Tensor] = None,
**_ **_
) -> torch.Tensor: ) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.fused_moe import fused_experts
...@@ -324,6 +325,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod: ...@@ -324,6 +325,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
a1_scale=layer.w13_input_scale, a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale, a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
) )
def select_gemm_impl( def select_gemm_impl(
......
import torch import torch
import numpy as np import numpy as np
try: try:
from lightop import awq_marlin_repack_w4a8 from lightop import awq_marlin_repack_w4a8
use_lightop = True use_lightop = False
except Exception: except Exception:
use_lightop = False use_lightop = False
......
...@@ -392,15 +392,21 @@ def apply_int8_linear( ...@@ -392,15 +392,21 @@ def apply_int8_linear(
azp_adj: Optional[torch.Tensor] = None, azp_adj: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
w8a8_strategy:Optional[int]=0, w8a8_strategy:Optional[int]=0,
input_quant_args: Optional[list[torch.Tensor]] = None
): ):
# ops.scaled_int8_quant supports both dynamic and static quant. # ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
if envs.USE_FUSED_RMS_QUANT and input_quant_args is not None:
assert len(input_quant_args) == 2
x_zp =None
x_q, x_scale = input_quant_args
else: # not USE_FUSED_RMS_QUANT
symmetric = azp_adj is None symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True: if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input) x_q, x_scale=per_token_quant_int8(input)
x_zp =None x_zp =None
else: else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input, x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale, input_scale,
......
...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig ...@@ -37,6 +37,8 @@ from transformers import PretrainedConfig
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
import vllm.envs as envs
from vllm.utils import direct_register_custom_op
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb
...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -900,6 +902,24 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
cache = torch.cat((cos, sin), dim=-1) cache = torch.cat((cos, sin), dim=-1)
return cache return cache
def rotary_embedding_deepseek_fuse(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
from lightop import op
op.rotary_embedding_deepseek_fuse(positions, query, key, head_size, cos_sin_cache, is_neox_style)
def rotary_embedding_deepseek_fuse_fake(positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor,
head_size: int, cos_sin_cache: torch.Tensor,
is_neox_style: bool) -> None:
pass
direct_register_custom_op(
op_name="rotary_embedding_deepseek_fuse",
op_func=rotary_embedding_deepseek_fuse,
mutates_args=[],
fake_impl=rotary_embedding_deepseek_fuse_fake,
)
def forward( def forward(
self, self,
positions: torch.Tensor, positions: torch.Tensor,
...@@ -938,6 +958,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding): ...@@ -938,6 +958,10 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
BLOCK_SIZE=BLOCK_SIZE, BLOCK_SIZE=BLOCK_SIZE,
num_warps=1) num_warps=1)
# if envs.VLLM_USE_LIGHTOP:
if False:
torch.ops.vllm.rotary_embedding_deepseek_fuse(positions, query, key, self.head_size, self.cos_sin_cache, self.is_neox_style)
else:
call(query) call(query)
call(key) call(key)
return query, key return query, key
......
...@@ -238,14 +238,28 @@ def get_model_architecture( ...@@ -238,14 +238,28 @@ def get_model_architecture(
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
else: else:
os.environ['LLAMA_NN'] = '1' os.environ['LLAMA_NN'] = '1'
if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0': if (architectures == ['BloomForCausalLM'] or architectures == ['FalconForCausalLM']) or os.getenv('LM_NN') == '0':
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
else: else:
os.environ['LM_NN'] = '1' os.environ['LM_NN'] = '1'
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
if os.getenv('GEMM_PAD') != '1': if os.getenv('GEMM_PAD') != '1':
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
else:
if architectures in [['DeepseekV3ForCausalLM'], ['DeepSeekMTPModel']]:
if not envs.is_set("VLLM_USE_LIGHTOP"):
os.environ['VLLM_USE_LIGHTOP'] = '1'
if not envs.is_set("VLLM_USE_OPT_CAT"):
os.environ['VLLM_USE_OPT_CAT'] = '1'
# awq相关配置 # awq相关配置
try: try:
if os.getenv('AWQ_MOE_SZ') == None: if os.getenv('AWQ_MOE_SZ') == None:
......
...@@ -274,6 +274,7 @@ class RocmPlatform(Platform): ...@@ -274,6 +274,7 @@ class RocmPlatform(Platform):
logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)") logger.info_once("Using Flash Attention backend on V1 engine. (only supports block size 64)")
return FLASH_ATTN_V1 return FLASH_ATTN_V1
else: else:
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
logger.info_once("Using Triton backend on V1 engine.") logger.info_once("Using Triton backend on V1 engine.")
return TRITON_ATTN_VLLM_V1 return TRITON_ATTN_VLLM_V1
......
...@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = { ...@@ -183,8 +183,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16, "bfloat16": torch.bfloat16,
"float": torch.float, "float": torch.float,
"fp8": torch.uint8, "fp8": torch.uint8,
# "fp8_e4m3": torch.uint8, "fp8_e4m3": torch.uint8,
# "fp8_e5m2": torch.uint8, "fp8_e5m2": torch.uint8,
"int8": torch.int8, "int8": torch.int8,
} }
......
...@@ -216,7 +216,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder, ...@@ -216,7 +216,6 @@ from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata) CommonAttentionMetadata)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
try: try:
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -894,6 +893,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -894,6 +893,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
q: torch.Tensor, q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
): ):
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
prefill_metadata = attn_metadata.prefill prefill_metadata = attn_metadata.prefill
...@@ -913,6 +913,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -913,6 +913,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i], cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
batch_size=attn_metadata.num_prefills, batch_size=attn_metadata.num_prefills,
seq_starts=prefill_metadata.chunked_context.starts[i], seq_starts=prefill_metadata.chunked_context.starts[i],
kv_dtype=self.kv_cache_dtype,
scale=kv_scale,
) )
kv_c_normed = workspace[:toks]\ kv_c_normed = workspace[:toks]\
...@@ -925,8 +927,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -925,8 +927,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
...@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -976,6 +979,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_pe: torch.Tensor, k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata, attn_metadata: MLACommonMetadata,
kv_scale=torch.tensor(1.0, dtype=torch.float32),
) -> torch.Tensor: ) -> torch.Tensor:
assert attn_metadata.prefill is not None assert attn_metadata.prefill is not None
...@@ -989,8 +993,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -989,8 +993,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
k_nope, v = kv_nope\ k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if k_nope.shape[0] > 1024: if k_nope.shape[0] > 1024:
from vllm.v1.attention.backends.mla.test_concat import lightop_concat_prefill_helper
k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)), k = lightop_concat_prefill_helper(k_nope, k_pe.expand((*k_nope.shape[:-1], -1)),
dim=2) dim=2)
else: else:
...@@ -1015,7 +1020,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1015,7 +1020,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_context: if has_context:
suffix_output, suffix_lse = output suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \ context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata) q, kv_c_and_k_pe_cache, attn_metadata, kv_scale)
output = torch.empty_like(suffix_output) output = torch.empty_like(suffix_output)
merge_attn_states( merge_attn_states(
...@@ -1104,7 +1109,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1104,7 +1109,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill: if has_prefill:
output[num_decode_tokens:] = self._forward_prefill( output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata) attn_metadata, kv_scale=layer._k_scale)
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
......
...@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend, ...@@ -20,7 +20,7 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable from vllm.v1.worker.block_table import BlockTable
from vllm import envs from vllm import envs
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -150,7 +150,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -150,7 +150,8 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
"FlashMLAImpl") "FlashMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype): if is_quantized_kv_cache(self.kv_cache_dtype):
if self.kv_cache_dtype != "fp8": if kv_cache_dtype == "fp8" or kv_cache_dtype == "fp8_e4m3" or kv_cache_dtype == "fp8_e5m2":
return
raise NotImplementedError( raise NotImplementedError(
"FlashMLA with other KV cache not yet supported") "FlashMLA with other KV cache not yet supported")
...@@ -166,8 +167,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -166,8 +167,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if envs.VLLM_USE_TRITON_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] <= 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
q = concat_helper_decode(q_nope, q_pe, dim=2)\ q = concat_helper_decode(q_nope, q_pe, dim=2)\
.unsqueeze(1) .unsqueeze(1)
else: else:
......
...@@ -5,7 +5,10 @@ from functools import reduce ...@@ -5,7 +5,10 @@ from functools import reduce
import pytest import pytest
import torch import torch
import math import math
from lightop import ds_cat import vllm.envs as envs
if envs.VLLM_USE_OPT_CAT:
from lightop import ds_cat
def test_concat_Acc_prefill(shape_pair, dim): def test_concat_Acc_prefill(shape_pair, dim):
......
...@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface): ...@@ -1047,16 +1047,14 @@ class Scheduler(SchedulerInterface):
for req in itertools.chain(running_reqs, resumed_reqs): for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id req_id = req.request_id
req_ids.append(req_id) req_ids.append(req_id)
num_tokens = (num_scheduled_tokens[req_id] - num_tokens = req.num_generated_token_ids
len(spec_decode_tokens.get(req_id, ())))
if self.use_pp: if self.use_pp:
# When using PP, the scheduler sends the sampled tokens back, # When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first- # because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't # stage worker and the last-stage worker. Otherwise, we don't
# need to send the sampled tokens back because the model runner # need to send the sampled tokens back because the model runner
# will cache them. # will cache them.
token_ids = req.all_token_ids[req.num_computed_tokens:req. token_ids = req.all_token_ids[-num_tokens:]
num_computed_tokens + num_tokens]
new_token_ids.append(token_ids) new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id]) new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens) num_computed_tokens.append(req.num_computed_tokens)
...@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface): ...@@ -1190,6 +1188,7 @@ class Scheduler(SchedulerInterface):
scheduled_spec_token_ids = ( scheduled_spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id)) scheduler_output.scheduled_spec_decode_tokens.get(req_id))
request.num_generated_token_ids = 1
if scheduled_spec_token_ids: if scheduled_spec_token_ids:
# num_computed_tokens represents the number of tokens # num_computed_tokens represents the number of tokens
# processed in the current step, considering scheduled # processed in the current step, considering scheduled
...@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface): ...@@ -1197,9 +1196,11 @@ class Scheduler(SchedulerInterface):
# num_computed_tokens is decreased by the number of rejected # num_computed_tokens is decreased by the number of rejected
# tokens, where is given by: # tokens, where is given by:
# len(scheduled_spec_token_ids) + 1 - len(generated_token_ids). # len(scheduled_spec_token_ids) + 1 - len(generated_token_ids).
num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 - num_tokens_rejected = (len(scheduled_spec_token_ids) + 1 -
len(generated_token_ids)) len(generated_token_ids))
request.num_computed_tokens -= num_tokens_rejected request.num_computed_tokens -= num_tokens_rejected
request.num_generated_token_ids = len(generated_token_ids)
spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats = self.make_spec_decoding_stats(
spec_decoding_stats, spec_decoding_stats,
num_draft_tokens=len(scheduled_spec_token_ids), num_draft_tokens=len(scheduled_spec_token_ids),
......
...@@ -79,6 +79,7 @@ class Request: ...@@ -79,6 +79,7 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy() self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = [] self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0 self.num_computed_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt self.cache_salt: Optional[str] = cache_salt
# Multi-modal related # Multi-modal related
......
...@@ -531,10 +531,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -531,10 +531,10 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if not is_last_rank: if not is_last_rank:
# Add new_token_ids to token_ids_cpu. # Add new_token_ids to token_ids_cpu.
start_token_index = num_computed_tokens start_token_index = num_computed_tokens
end_token_index = num_computed_tokens + len(new_token_ids) end_token_index = num_computed_tokens + 1
self.input_batch.token_ids_cpu[ self.input_batch.token_ids_cpu[
req_index, req_index,
start_token_index:end_token_index] = new_token_ids start_token_index:end_token_index] = new_token_ids[-1]
self.input_batch.num_tokens_no_spec[ self.input_batch.num_tokens_no_spec[
req_index] = end_token_index req_index] = end_token_index
self.input_batch.num_tokens[req_index] = end_token_index self.input_batch.num_tokens[req_index] = end_token_index
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment