Commit a54eca71 authored by zhuwenwen
Browse files

update triton_mla.py

parent 30e0b082
......@@ -8,7 +8,7 @@ from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
-from .triton_config import get_nearest_config, get_attention_mla_configs, get_config
+from .triton_config import get_nearest_config, get_attention_mla_configs, get_config, get_attention_mla_configs_json
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
......@@ -687,7 +687,7 @@ class TritonMLAImpl(MLACommonImpl[TritonMLAMetadata]):
"are not implemented for "
"TritonMLAImpl")
-self.attn_configs = get_attention_mla_configs(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
+self.attn_configs = get_attention_mla_configs_json(self.num_heads, 1, self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank, "fp16")
def _forward_prefill(
self,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment