Commit 0ac26aec authored by zhuwenwen
Browse files

update fused_rms_norm_rope_contiguous params

Set VLLM_REJECT_SAMPLE_OPT=1 by default (when not already set) for dpsk-v3 (DeepSeek-V3)
parent e9532d9e
...@@ -255,8 +255,11 @@ def get_model_architecture( ...@@ -255,8 +255,11 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): # if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' # os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else: else:
if not envs.is_set("VLLM_USE_PD_SPLIT"): if not envs.is_set("VLLM_USE_PD_SPLIT"):
os.environ['VLLM_USE_PD_SPLIT'] = '1' os.environ['VLLM_USE_PD_SPLIT'] = '1'
...@@ -291,6 +294,8 @@ def get_model_architecture( ...@@ -291,6 +294,8 @@ def get_model_architecture(
os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1' os.environ['VLLM_USE_LIGHTOP_FILL_MOE_ALIGN'] = '1'
if not envs.is_set("VLLM_USE_CAT_MLA"): if not envs.is_set("VLLM_USE_CAT_MLA"):
os.environ['VLLM_USE_CAT_MLA'] = '1' os.environ['VLLM_USE_CAT_MLA'] = '1'
if not envs.is_set("VLLM_REJECT_SAMPLE_OPT"):
os.environ['VLLM_REJECT_SAMPLE_OPT'] = '1'
# if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"): # if not envs.is_set("VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT"):
# os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1' # os.environ['VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT'] = '1'
else: else:
......
...@@ -1132,8 +1132,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1132,8 +1132,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
has_prefill = attn_metadata.num_prefills > 0 has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens num_decode_tokens = attn_metadata.num_decode_tokens
decode_q = q[:num_decode_tokens]
prefill_k_pe = k_pe[num_decode_tokens:] prefill_k_pe = k_pe[num_decode_tokens:]
if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if not envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
decode_q = q[:num_decode_tokens] decode_q = q[:num_decode_tokens]
...@@ -1166,11 +1164,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1166,11 +1164,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
kv_cache_dtype_str = self.kv_cache_dtype kv_cache_dtype_str = self.kv_cache_dtype
fused_rms_norm_rope_contiguous( fused_rms_norm_rope_contiguous(
positions, positions[:num_actual_toks, ...],
q, q,
k_pe.squeeze(1), k_pe.squeeze(1),
k_c_normed, # not normed k_c_normed, # not normed
key_normed, # normed key_normed[:num_actual_toks, ...], # normed
weight, weight,
cos_sin_cache, cos_sin_cache,
attn_metadata.slot_mapping.flatten(), attn_metadata.slot_mapping.flatten(),
...@@ -1183,7 +1181,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1183,7 +1181,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_prefill: if has_prefill:
if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT: if envs.VLLM_USE_LIGHTOP_RMS_ROPE_CONCAT:
prefill_k_c_normed = key_normed[num_decode_tokens:] prefill_k_c_normed = key_normed[:num_actual_toks, ...]
prefill_k_c_normed = prefill_k_c_normed[num_decode_tokens:]
output[num_decode_tokens:] = self._forward_prefill( output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache, prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata, kv_scale=layer._k_scale) attn_metadata, kv_scale=layer._k_scale)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment