Unverified commit 093cd3f0, authored by YiYi Xu, committed by GitHub
Browse files

fix dispatch_attention_fn check (#12636)

* fix

* fix
parent aecf0c53
...@@ -383,12 +383,18 @@ def _check_shape( ...@@ -383,12 +383,18 @@ def _check_shape(
attn_mask: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
# Expected shapes:
# query: (batch_size, seq_len_q, num_heads, head_dim)
# key: (batch_size, seq_len_kv, num_heads, head_dim)
# value: (batch_size, seq_len_kv, num_heads, head_dim)
# attn_mask: (seq_len_q, seq_len_kv) or (batch_size, seq_len_q, seq_len_kv)
# or (batch_size, num_heads, seq_len_q, seq_len_kv)
if query.shape[-1] != key.shape[-1]: if query.shape[-1] != key.shape[-1]:
raise ValueError("Query and key must have the same last dimension.") raise ValueError("Query and key must have the same head dimension.")
if query.shape[-2] != value.shape[-2]: if key.shape[-3] != value.shape[-3]:
raise ValueError("Query and value must have the same second to last dimension.") raise ValueError("Key and value must have the same sequence length.")
if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: if attn_mask is not None and attn_mask.shape[-1] != key.shape[-3]:
raise ValueError("Attention mask must match the key's second to last dimension.") raise ValueError("Attention mask must match the key's sequence length.")
# ===== Helper functions ===== # ===== Helper functions =====
......
...@@ -42,7 +42,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules" ...@@ -42,7 +42,7 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
DIFFUSERS_REQUEST_TIMEOUT = 60 DIFFUSERS_REQUEST_TIMEOUT = 60
DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_VARS_TRUE_VALUES
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment