Unverified Commit 38b792aa authored by flybird1111's avatar flybird1111 Committed by GitHub
Browse files

[coloattention] fix import error (#4380)

fixed an import error
parent 25c57b9f
from .mha import ColoAttention
__all__ = ['ColoAttention']
...@@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN ...@@ -9,7 +9,7 @@ from colossalai.kernel.cuda_native.mha.mem_eff_attn import HAS_MEM_EFF_ATTN
from colossalai.testing import clear_cache_before_run, parameterize from colossalai.testing import clear_cache_before_run, parameterize
if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN: if HAS_MEM_EFF_ATTN or HAS_FLASH_ATTN:
from colossalai.kernel.cuda_native.mha.mha import ColoAttention from colossalai.kernel.cuda_native import ColoAttention
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
DTYPE = [torch.float16, torch.bfloat16, torch.float32] DTYPE = [torch.float16, torch.bfloat16, torch.float32]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment