Unverified Commit e58294dd authored by JGSweets's avatar JGSweets Committed by GitHub
Browse files

[Bugfix] Add verbose error if scipy is missing for blocksparse attention (#5695)

parent f1e15da6
...@@ -6,7 +6,14 @@ from functools import lru_cache ...@@ -6,7 +6,14 @@ from functools import lru_cache
import torch import torch
import triton import triton
from scipy import sparse
try:
from scipy import sparse
except ImportError as err:
raise ImportError("Please install scipy via "
"`pip install scipy` to use "
"BlockSparseAttention in "
"models such as Phi-3.") from err
def dense_to_crow_col(x: torch.Tensor): def dense_to_crow_col(x: torch.Tensor):
...@@ -77,11 +84,11 @@ def _get_sparse_attn_mask_homo_head( ...@@ -77,11 +84,11 @@ def _get_sparse_attn_mask_homo_head(
): ):
""" """
:return: a tuple of 3: :return: a tuple of 3:
- tuple of crow_indices, col_indices representation - tuple of crow_indices, col_indices representation
of CSR format. of CSR format.
- block dense mask - block dense mask
- all token dense mask (be aware that it can be - all token dense mask (be aware that it can be
OOM if it is too big) if `return_dense==True`, OOM if it is too big) if `return_dense==True`,
otherwise, None otherwise, None
""" """
with torch.no_grad(): with torch.no_grad():
...@@ -148,10 +155,10 @@ def get_sparse_attn_mask( ...@@ -148,10 +155,10 @@ def get_sparse_attn_mask(
:param dense_mask_type: "binary" (0 for skip token, 1 for others) :param dense_mask_type: "binary" (0 for skip token, 1 for others)
or "bias" (-inf for skip token, 0 or others) or "bias" (-inf for skip token, 0 or others)
:return: a tuple of 3: :return: a tuple of 3:
- tuple of crow_indices, col_indices representation - tuple of crow_indices, col_indices representation
of CSR format. of CSR format.
- block dense mask - block dense mask
- all token dense mask (be aware that it can be OOM if it - all token dense mask (be aware that it can be OOM if it
is too big) if `return_dense==True`, otherwise, None is too big) if `return_dense==True`, otherwise, None
""" """
assert dense_mask_type in ("binary", "bias") assert dense_mask_type in ("binary", "bias")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment