Commit d6dc122f authored by zhuwenwen's avatar zhuwenwen
Browse files

update the conditions for pad_v on v0

parent aaa89e82
......@@ -1043,10 +1043,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120 or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
if not current_platform.is_rocm():
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)
else:
self._pad_v = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count == 120
def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
return_softmax_lse, **kwargs):
maybe_padded_v = v
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment