Commit 263f45a4 authored by zhuwenwen
Browse files

update VLLM_USE_OPT_RESHAPE_AND_CACHE to support bf16 and qwen3-dense

parent ac28ab22
...@@ -211,6 +211,9 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], ...@@ -211,6 +211,9 @@ def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module],
os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1' os.environ['VLLM_USE_FUSE_SILU_AND_MUL'] = '1'
if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"): if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1' os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '1'
# if architectures in [['Qwen3ForCausalLM']]:
# if not envs.is_set("VLLM_USE_OPT_RESHAPE_AND_CACHE"):
# os.environ['VLLM_USE_OPT_RESHAPE_AND_CACHE'] = '0'
if architectures in [['DeepseekV32ForCausalLM']]: if architectures in [['DeepseekV32ForCausalLM']]:
if not envs.is_set("VLLM_USE_V32_ENCODE"): if not envs.is_set("VLLM_USE_V32_ENCODE"):
......
...@@ -898,7 +898,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -898,7 +898,7 @@ class FlashAttentionImpl(AttentionImpl):
# op uses the slot_mapping's shape to determine the number of # op uses the slot_mapping's shape to determine the number of
# actual tokens. # actual tokens.
if current_platform.is_rocm(): if current_platform.is_rocm():
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE and key.dtype == value.dtype == torch.float16: if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
from lightop import reshape_and_cache_cuda from lightop import reshape_and_cache_cuda
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment