Commit d431f167 authored by Tri Dao's avatar Tri Dao
Browse files

Import torch before flash_attn_2_cuda

parent 0e8c46ae
import torch
import torch.nn as nn
from einops import rearrange

# isort: off
# We need to import the CUDA kernels after importing torch:
# flash_attn_2_cuda links against libtorch symbols, which are only
# resolvable once torch's shared libraries have been loaded.
import flash_attn_2_cuda as flash_attn_cuda

# isort: on
def _get_block_size(device, head_dim, is_dropout, is_causal): def _get_block_size(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel # This should match the block sizes in the CUDA kernel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment