Commit 18ab72c9 authored by Tong WU's avatar Tong WU Committed by LeiWang1999
Browse files

[BugFix] Fix import error in nsa examples when `fla.__version__ >=0.2.1` (#579)

* Update FLA import path for `prepare_token_indices`

* Update FLA import path for `prepare_token_indices`

* Compare versions via packaging.version.parse
parent 4c24a69e
# ruff: noqa # ruff: noqa
import torch import torch
from typing import Optional, Union from typing import Optional, Union
from packaging.version import parse
import torch import torch
import triton import triton
from fla.ops.common.utils import prepare_token_indices import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa from reference import naive_nsa
from einops import rearrange from einops import rearrange
......
# ruff: noqa # ruff: noqa
import torch import torch
from typing import Optional, Union from typing import Optional, Union
from packaging.version import parse
import tilelang import tilelang
from tilelang import language as T from tilelang import language as T
import tilelang.testing import tilelang.testing
from fla.ops.common.utils import prepare_token_indices import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from reference import naive_nsa from reference import naive_nsa
from einops import rearrange from einops import rearrange
......
# ruff: noqa # ruff: noqa
import torch import torch
from typing import Optional, Union from typing import Optional, Union
from packaging.version import parse
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from fla.ops.common.utils import prepare_token_indices import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa from reference import naive_nsa
from einops import rearrange from einops import rearrange
......
# ruff: noqa # ruff: noqa
import torch import torch
from typing import Optional, Union from typing import Optional, Union
from packaging.version import parse
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from fla.ops.common.utils import prepare_token_indices import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa from reference import naive_nsa
from einops import rearrange from einops import rearrange
......
# ruff: noqa # ruff: noqa
import torch import torch
from typing import Optional, Union from typing import Optional, Union
from packaging.version import parse
import torch import torch
import triton import triton
import triton.language as tl import triton.language as tl
from fla.ops.common.utils import prepare_token_indices import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa from reference import naive_nsa
from einops import rearrange from einops import rearrange
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment