"vscode:/vscode.git/clone" did not exist on "0db19da01f2322485e6e2fe84cec39869e0f35cc"
Unverified commit c9bf3877, authored by Yichen Yan, committed by GitHub

Reduce overhead for fa by not calling heavy CUDA property check (#7375)

parent de2dd738
```diff
-from typing import List, Optional, Tuple, Union
+from functools import lru_cache
+from typing import Optional, Union
 import torch
 import torch.nn as nn
@@ -9,6 +10,7 @@ except:
     raise ImportError("Can not import sgl_kernel. Please check your installation.")
 
+@lru_cache(maxsize=1)
 def is_fa3_supported(device=None) -> bool:
     # There some fa3 FYI
     # FA3 can fail without a enough shared memory for a some shapes, such as higher
@@ -18,10 +20,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 
 def maybe_contiguous(x):
```
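The change combines two standard Python techniques: `functools.lru_cache(maxsize=1)` memoizes the result so the CUDA property lookup runs at most once per process, and moving the cheap `torch.version.cuda` string comparison to the left of `and` lets short-circuit evaluation skip `torch.cuda.get_device_capability()` entirely on older CUDA builds. Below is an illustrative micro-benchmark, not part of the commit, contrasting the two orderings; the function names `old_check`/`new_check` are stand-ins, and it assumes a CUDA-enabled PyTorch build.

```python
# Illustrative micro-benchmark (not from the commit): old uncached ordering vs.
# the new cached, short-circuited one. Requires a CUDA-enabled PyTorch build;
# on CPU-only builds get_device_capability() raises.
import timeit
from functools import lru_cache

import torch


def old_check() -> bool:
    # Old ordering: the device-capability query (a CUDA property lookup)
    # executes on every call, before the cheap version-string comparison.
    return (
        torch.cuda.get_device_capability()[0] == 9
        or torch.cuda.get_device_capability()[0] == 8
    ) and (torch.version.cuda >= "12.3")


@lru_cache(maxsize=1)
def new_check() -> bool:
    # New ordering: `and` short-circuits, so on CUDA < 12.3 the heavy call is
    # never made, and lru_cache turns every repeat call into a dict lookup.
    return (torch.version.cuda >= "12.3") and (
        torch.cuda.get_device_capability()[0] == 9
        or torch.cuda.get_device_capability()[0] == 8
    )


if __name__ == "__main__":
    print("old:", timeit.timeit(old_check, number=10_000))
    print("new:", timeit.timeit(new_check, number=10_000))
```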
The commit applies the same reordering to the copy of `is_fa3_supported` in a second file:

```diff
@@ -25,10 +25,10 @@ def is_fa3_supported(device=None) -> bool:
     # https://docs.nvidia.com/cuda/cuda-c-programming-guide/#shared-memory-8-x
     # And for sgl-kernel right now, we can build fa3 on sm80/sm86/sm89/sm90a.
     # That means if you use A100/A*0/L20/L40/L40s/4090 you can use fa3.
-    return (
+    return (torch.version.cuda >= "12.3") and (
         torch.cuda.get_device_capability(device)[0] == 9
         or torch.cuda.get_device_capability(device)[0] == 8
-    ) and (torch.version.cuda >= "12.3")
+    )
 
 DISABLE_BACKWARD = True
```
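One detail worth noting about the caching: `lru_cache(maxsize=1)` keys on the argument tuple, so calls with different `device` values evict one another rather than caching side by side, and `cache_clear()` resets the memo. A minimal sketch of that behavior, using only the standard library; `probe` is a hypothetical stand-in, not sgl-kernel code:

```python
from functools import lru_cache


@lru_cache(maxsize=1)
def probe(device=None) -> str:
    # Hypothetical stand-in for is_fa3_supported: the body runs only on a
    # cache miss, mirroring how the decorated check queries CUDA at most once.
    print(f"computing for {device!r}")
    return f"result-{device}"


probe(None)  # miss: body runs
probe(None)  # hit: silent
probe(0)     # miss: evicts the None entry, since maxsize=1
probe(None)  # miss again: the None entry was evicted
print(probe.cache_info())  # CacheInfo(hits=1, misses=3, maxsize=1, currsize=1)
```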