Commit 1217257c authored by zhuwenwen's avatar zhuwenwen
Browse files

fix run error

parent 8301427e
......@@ -226,14 +226,16 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
tile_scheduler_metadata,
num_splits)
num_splits,
is_fp8_kvcache,
indices,
)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
......
......@@ -2062,8 +2062,6 @@ class FusedMoE(CustomOp):
router_logits=router_logits,
use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
use_nn_moe=self.use_nn_moe,
use_fused_gate=self.use_fused_gate,
i_q=i_q,
i_s=i_s,
)
......
......@@ -181,16 +181,16 @@ def get_rope(
scaling_alpha,
dtype,
)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
rotary_emb = DynamicNTKScalingRotaryEmbedding(
head_size,
rotary_dim,
max_position,
base,
is_neox_style,
scaling_factor,
dtype,
)
else:
raise ValueError(
"Dynamic rope scaling must contain either 'alpha' or 'factor' field"
......
......@@ -228,11 +228,10 @@ class RocmPlatform(Platform):
logger.info_once("Using Sparse MLA backend on V1 engine.")
return AttentionBackendEnum.ROCM_AITER_MLA_SPARSE.get_path()
if use_mla:
# if use_sparse:
if attn_selector_config.use_mla:
# if attn_selector_config.use_sparse:
# logger.info_once("Using Sparse MLA backend on V1 engine.")
# return ("vllm.v1.attention.backends.mla.flashmla_sparse."
# "FlashMLASparseBackend")
# return AttentionBackendEnum.FLASHMLA_SPARSE.get_path()
use_flashmla = selected_backend == AttentionBackendEnum.FLASHMLA or envs.VLLM_USE_FLASH_MLA
use_triton = selected_backend == AttentionBackendEnum.TRITON_MLA or (
......
......@@ -56,6 +56,7 @@ from vllm.v1.attention.backends.utils import (
get_kv_cache_layout,
)
from vllm.v1.kv_cache_interface import AttentionSpec
import vllm.envs as envs
logger = init_logger(__name__)
......
......@@ -81,8 +81,8 @@ class Worker(WorkerBase):
)
# configure float32 matmul precision according to vLLM env.
precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
torch.backends.cuda.matmul.fp32_precision = precision
# precision = envs.VLLM_FLOAT32_MATMUL_PRECISION
# torch.backends.cuda.matmul.fp32_precision = precision
if self.model_config.trust_remote_code:
# note: lazy import to avoid importing torch before initializing
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment