Commit 0ac26aec authored by zhuwenwen's avatar zhuwenwen
Browse files

update fused_rms_norm_rope_contiguous params

set VLLM_REJECT_SAMPLE_OPT=1 for dpsk-v3
parent e9532d9e
......@@ -255,8 +255,11 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else:
if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1'
......@@ -291,6 +294,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else:
......
......@@ -1132,8 +1132,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens]
......@@ -1166,11 +1164,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous(
positions,
positions[:num_actual_toks, ...],
q,
k_pe.squeeze(1),
k_c_normed, # not normed
key_normed, # normed
key_normed[:num_actual_toks, ...], # normed
weight,
cos_sin_cache,
attn_metadata.slot_mapping.flatten(),
......@@ -1183,7 +1181,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[num_decode_tokens:]
prefill_k_c_normed = key_normed[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:]
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata, kv_scale=layer._k_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment