Commit 5ec2f551 authored by wangshankun's avatar wangshankun
Browse files

fix bug: flashinfer version compatibility

parent 6ac3cee7
...@@ -2,6 +2,10 @@ import torch ...@@ -2,6 +2,10 @@ import torch
try: try:
import flashinfer import flashinfer
from packaging import version
flashinfer_version = version.parse(flashinfer.__version__)
has_o_dtype = flashinfer_version >= version.parse("0.2.6.post1")
except ImportError: except ImportError:
flashinfer = None flashinfer = None
...@@ -29,7 +33,8 @@ def radial_attn( ...@@ -29,7 +33,8 @@ def radial_attn(
indptr = get_indptr_from_mask(mask, query) indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query) indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
kwargs = dict(
indptr=indptr, indptr=indptr,
indices=indices, indices=indices,
M=seqlen, M=seqlen,
...@@ -43,6 +48,10 @@ def radial_attn( ...@@ -43,6 +48,10 @@ def radial_attn(
kv_data_type=key.dtype, kv_data_type=key.dtype,
use_fp16_qk_reduction=True, use_fp16_qk_reduction=True,
) )
if has_o_dtype:
kwargs["o_data_type"] = query.dtype
bsr_wrapper.plan(**kwargs)
o = bsr_wrapper.run(query, key, value) o = bsr_wrapper.run(query, key, value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment