Unverified Commit c9bf3877 authored by Yichen Yan, committed by GitHub

Reduce overhead for fa by not calling heavy CUDA property check (#7375)

parent de2dd738
-from typing import List, Optional, Tuple, Union
+from functools import lru_cache
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
...
@@ -9,6 +10,7 @@ except:
     raise ImportError("Can not import sgl_kernel. Please check your installation.")
 
 
+@lru_cache(maxsize=1)
 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
     # FA3 can fail without a enough shared memory for a some shapes, such as higher
...
@@ -18,10 +20,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 def maybe_contiguous(x):
...
@@ -25,10 +25,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 DISABLE_BACKWARD = True
...
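
For context, a minimal sketch of the pattern this commit introduces, not the exact sgl-kernel source: the FA3 capability check is memoized with functools.lru_cache so the CUDA device-property query runs at most once per process, and the cheap torch.version.cuda string comparison is evaluated first so it can short-circuit before touching the device. The None guard for CPU-only PyTorch builds is an addition for self-containment and is not part of the commit.

# Sketch only; mirrors the commit's pattern, with a CPU-build guard added here.
from functools import lru_cache

import torch


@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
    # Cheap check first: a string comparison on the CUDA toolkit version.
    # torch.version.cuda is None on CPU-only builds (guard added for the sketch).
    if torch.version.cuda is None or torch.version.cuda < "12.3":
        return False
    # Heavy check second: get_device_capability can initialize the CUDA context,
    # which is the overhead the commit title refers to. lru_cache ensures it
    # runs only on the first call; later calls return the cached result.
    return torch.cuda.get_device_capability(device)[0] in (8, 9)

With this shape, code on the attention hot path can call is_fa3_supported() per forward pass and pay only a cached lookup after the first invocation instead of re-querying device properties each time.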