Commit de889cb6 authored by zhuwenwen's avatar zhuwenwen
Browse files

sync v0.15.1

parent c721b814
......@@ -156,8 +156,16 @@ class MistralDecoderLayer(LlamaDecoderLayer):
)
self.layer_idx = int(prefix.split(sep=".")[-1])
quant_config = self.get_quant_config(vllm_config)
config = config or vllm_config.model_config.hf_config
do_fusion = getattr(
quant_config, "enable_quantization_scaling_fusion", False
) and vllm_config.cache_config.cache_dtype.startswith("fp8")
if do_fusion:
self.input_layernorm.quant_scaling_from = self.self_attn.qkv_proj
self.post_attention_layernorm.quant_scaling_from = self.mlp.gate_up_proj
if getattr(config, "ada_rms_norm_t_cond", False):
self.ada_rms_norm_t_cond = nn.Sequential(
ColumnParallelLinear(
......@@ -339,4 +347,4 @@ class MistralForCausalLM(LlamaForCausalLM):
elif item in mapping and mapping[item] not in name:
name = name.replace(item, mapping[item])
return name, loaded_weight
return name, loaded_weight
\ No newline at end of file
......@@ -284,6 +284,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
scheduler_metadata.tile_scheduler_metadata = tile_scheduler_metadata
scheduler_metadata.num_splits = num_splits
if self.kv_cache_dtype.startswith("fp8"):
o, lse = flash_mla_with_kvcache_fp8(
q=q,
......
......@@ -302,7 +302,6 @@ def chunked_prefill_paged_decode(
block_size = value_cache.shape[3]
num_seqs = len(seq_lens)
num_query_heads = query.shape[1]
# key may be None in cross-attention decode (already cached from encoder)
num_kv_heads = key.shape[1]
num_queries_per_kv = query.shape[1] // key.shape[1]
head_size = query.shape[2]
......
......@@ -22,7 +22,7 @@ else:
if current_platform.is_cuda():
try:
import vllm._flashmla_extension_C # noqa: F401
_flashmla_extension_C_AVAILABLE = True
except ImportError:
_flashmla_extension_C_AVAILABLE = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment