Unverified commit 03a3becc, authored by huismiling, committed by GitHub

Cambricon MLUs support SDPA and flash_attn (#31102)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon support SDPA and flash_attn
parent ac946aac
@@ -329,6 +329,9 @@ def is_torch_sdpa_available():
     # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
     # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
     # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
+    # NOTE: MLU is OK with non-contiguous inputs.
+    if is_torch_mlu_available():
+        return version.parse(_torch_version) >= version.parse("2.1.0")
     # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
     return version.parse(_torch_version) >= version.parse("2.1.1")
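
With this hunk, is_torch_sdpa_available() accepts torch >= 2.1.0 on MLU, since MLU is not affected by the non-contiguous-input issue that motivates the 2.1.1 floor elsewhere. A minimal usage sketch (not part of this commit), assuming torch_mlu is installed so the "mlu" device exists and is_torch_mlu_available() (importable from transformers.utils) returns True; the checkpoint name is only a placeholder:

import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_torch_mlu_available, is_torch_sdpa_available

if is_torch_mlu_available() and is_torch_sdpa_available():
    # After this commit, SDPA is reported available on MLU from torch 2.1.0.
    model = AutoModelForCausalLM.from_pretrained(
        "gpt2",  # placeholder checkpoint
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
    ).to("mlu")  # the "mlu" device requires the torch_mlu extension
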
@@ -795,7 +798,7 @@ def is_flash_attn_2_available():
     # Let's add an extra check to see if cuda is available
     import torch

-    if not torch.cuda.is_available():
+    if not (torch.cuda.is_available() or is_torch_mlu_available()):
         return False

     if torch.version.cuda:
@@ -803,6 +806,8 @@ def is_flash_attn_2_available():
     elif torch.version.hip:
         # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
         return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+    elif is_torch_mlu_available():
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3")
     else:
         return False
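
Together, these two hunks let is_flash_attn_2_available() return True on MLU when a flash_attn build >= 2.3.3 is installed. A sketch of choosing an attention backend with these checks (illustrative selection logic, not part of this commit; it assumes an MLU-capable flash_attn wheel is present):

import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

# Prefer FlashAttention-2 when a compatible flash_attn package is importable,
# otherwise fall back to SDPA, and finally to the eager implementation.
if is_flash_attn_2_available():
    attn_impl = "flash_attention_2"
elif is_torch_sdpa_available():
    attn_impl = "sdpa"
else:
    attn_impl = "eager"

model = AutoModelForCausalLM.from_pretrained(
    "gpt2",  # placeholder checkpoint
    torch_dtype=torch.float16,  # flash_attention_2 expects fp16/bf16 weights
    attn_implementation=attn_impl,
)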