Unverified Commit bfa2c0bb authored by Andreas Karatzas's avatar Andreas Karatzas Committed by GitHub
Browse files

[ROCm][Bugfix] Fix RuntimeError in MMEncoderAttention by replacing .view() with .reshape() (#31203)


Signed-off-by: default avatarAndreas Karatzas <akaratza@amd.com>
parent f7900686
...@@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items): ...@@ -19,7 +19,7 @@ def pytest_collection_modifyitems(config, items):
return return
# Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers # Disable Flash/MemEfficient SDP on ROCm to avoid HF Transformers
# accuracy issues # accuracy issues: https://github.com/vllm-project/vllm/issues/30167
# TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace # TODO: Remove once ROCm SDP accuracy issues are resolved on HuggingFace
torch.backends.cuda.enable_flash_sdp(False) torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False) torch.backends.cuda.enable_mem_efficient_sdp(False)
......
...@@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp): ...@@ -136,7 +136,7 @@ class MMEncoderAttention(CustomOp):
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
) )
if is_reshaped: if is_reshaped:
output = output.view(bsz, q_len, -1) output = output.reshape(bsz, q_len, -1)
return output return output
def _forward_fa( def _forward_fa(
...@@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp): ...@@ -174,7 +174,7 @@ class MMEncoderAttention(CustomOp):
fa_version=self._fa_version, fa_version=self._fa_version,
) )
if is_reshaped: if is_reshaped:
output = output.view(bsz, q_len, -1) output = output.reshape(bsz, q_len, -1)
return output return output
def forward_native( def forward_native(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment