Commit 18ab72c9 authored by Tong WU's avatar Tong WU Committed by LeiWang1999
Browse files

[BugFix] Fix import error in nsa examples when `fla.__version__ >=0.2.1` (#579)

* Update FLA import path for `prepare_token_indices`

* Update FLA import path for `prepare_token_indices`

* Compare versions via packaging.version.parse
parent 4c24a69e
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
from fla.ops.common.utils import prepare_token_indices
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
......
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import tilelang
from tilelang import language as T
import tilelang.testing
from fla.ops.common.utils import prepare_token_indices
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from reference import naive_nsa
from einops import rearrange
......
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_bwd, autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
......
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
......
# ruff: noqa
import torch
from typing import Optional, Union
from packaging.version import parse
import torch
import triton
import triton.language as tl
from fla.ops.common.utils import prepare_token_indices
import fla
if parse(fla.__version__) < parse("0.2.1"):
from fla.ops.common.utils import prepare_token_indices
else:
from fla.ops.utils import prepare_token_indices
from fla.utils import autocast_custom_fwd, contiguous
from reference import naive_nsa
from einops import rearrange
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment