Commit 5ec2f551 authored by wangshankun
Browse files

fix bug: flashinfer version compatibility

parent 6ac3cee7
......@@ -2,6 +2,10 @@ import torch
# flashinfer is an optional dependency; when it (or `packaging`, which is used
# for the version check) cannot be imported, fall back to "not available".
try:
    import flashinfer
    from packaging import version

    # radial_attn only passes `o_data_type` to `plan` when this flag is set,
    # i.e. for flashinfer 0.2.6.post1 and newer.
    flashinfer_version = version.parse(flashinfer.__version__)
    has_o_dtype = flashinfer_version >= version.parse("0.2.6.post1")
except ImportError:
    flashinfer = None
    # Define the version-related names on the fallback path too, so a later
    # read of `has_o_dtype` cannot raise NameError.
    flashinfer_version = None
    has_o_dtype = False
......@@ -29,7 +33,8 @@ def radial_attn(
indptr = get_indptr_from_mask(mask, query)
indices = get_indices_from_mask(mask, query)
bsr_wrapper.plan(
kwargs = dict(
indptr=indptr,
indices=indices,
M=seqlen,
......@@ -43,6 +48,10 @@ def radial_attn(
kv_data_type=key.dtype,
use_fp16_qk_reduction=True,
)
if has_o_dtype:
kwargs["o_data_type"] = query.dtype
bsr_wrapper.plan(**kwargs)
o = bsr_wrapper.run(query, key, value)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment