"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "c5ea35dc19c02bea1c3e3c2b5c47a3d0ad81f9ec"
Unverified Commit 6cdcbcc6 authored by JieXin Liang's avatar JieXin Liang Committed by GitHub
Browse files

[fix] fix enable_pdl for blackwell (#9011)

parent c480a3f6
...@@ -2,7 +2,7 @@ from dataclasses import dataclass ...@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Optional from typing import Optional
import torch import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
...@@ -41,7 +41,7 @@ def rmsnorm( ...@@ -41,7 +41,7 @@ def rmsnorm(
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
if enable_pdl is None: if enable_pdl is None:
enable_pdl = is_hopper_arch() enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl) torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out return out
...@@ -77,7 +77,7 @@ def fused_add_rmsnorm( ...@@ -77,7 +77,7 @@ def fused_add_rmsnorm(
                If None, PDL is automatically enabled on architectures that support it (compute capability >= 9.0, i.e. Hopper and newer).
""" """
if enable_pdl is None: if enable_pdl is None:
enable_pdl = is_hopper_arch() enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.fused_add_rmsnorm.default( torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl input, residual, weight, eps, enable_pdl
) )
...@@ -117,7 +117,7 @@ def gemma_rmsnorm( ...@@ -117,7 +117,7 @@ def gemma_rmsnorm(
if out is None: if out is None:
out = torch.empty_like(input) out = torch.empty_like(input)
if enable_pdl is None: if enable_pdl is None:
enable_pdl = is_hopper_arch() enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl) torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out return out
...@@ -153,7 +153,7 @@ def gemma_fused_add_rmsnorm( ...@@ -153,7 +153,7 @@ def gemma_fused_add_rmsnorm(
                If None, PDL is automatically enabled on architectures that support it (compute capability >= 9.0, i.e. Hopper and newer).
""" """
if enable_pdl is None: if enable_pdl is None:
enable_pdl = is_hopper_arch() enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default( torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl input, residual, weight, eps, enable_pdl
) )
......
...@@ -43,8 +43,8 @@ def _to_tensor_scalar_tuple(x): ...@@ -43,8 +43,8 @@ def _to_tensor_scalar_tuple(x):
@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
    """Return True if the current CUDA device supports Programmatic
    Dependent Launch (PDL).

    PDL requires compute capability >= 9.0, i.e. Hopper and newer
    architectures (including Blackwell) — not only ``major == 9``,
    which is why this checks ``>=`` rather than ``==``.

    The result is cached (``lru_cache(maxsize=1)``) since the device
    architecture does not change during the process lifetime.
    NOTE(review): the cache keys on nothing, so in a multi-GPU process
    with heterogeneous devices this reflects whichever device was
    current on the first call — confirm that is acceptable for callers.
    """
    device = torch.cuda.current_device()
    # Minor version is irrelevant to the PDL check; only major matters.
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment