"tests/python/vscode:/vscode.git/clone" did not exist on "4bc256b11faa83fe16675b69bd25c3432a754151"
Unverified commit 6cdcbcc6, authored by JieXin Liang, committed by GitHub
Browse files

[fix] fix enable_pdl for blackwell (#9011)

parent c480a3f6
......@@ -2,7 +2,7 @@ from dataclasses import dataclass
from typing import Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
......@@ -41,7 +41,7 @@ def rmsnorm(
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
......@@ -77,7 +77,7 @@ def fused_add_rmsnorm(
If None, will be automatically enabled on architectures that support PDL (compute capability >= 9, i.e. Hopper and newer).
"""
if enable_pdl is None:
enable_pdl = is_hopper_arch()
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
......@@ -117,7 +117,7 @@ def gemma_rmsnorm(
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out
......@@ -153,7 +153,7 @@ def gemma_fused_add_rmsnorm(
If None, will be automatically enabled on architectures that support PDL (compute capability >= 9, i.e. Hopper and newer).
"""
if enable_pdl is None:
enable_pdl = is_hopper_arch()
enable_pdl = is_arch_support_pdl()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
......
......@@ -43,8 +43,8 @@ def _to_tensor_scalar_tuple(x):
@functools.lru_cache(maxsize=1)
def is_arch_support_pdl() -> bool:
    """Return True if the current CUDA device supports PDL.

    PDL (Programmatic Dependent Launch) is available from compute
    capability 9.0 (Hopper) onward; checking ``major >= 9`` also covers
    Blackwell and future architectures, unlike the previous
    ``major == 9`` (Hopper-only) check.

    NOTE(review): the result is cached for the process lifetime, so a
    later switch of the current device is not reflected — confirm this
    matches caller expectations.
    """
    device = torch.cuda.current_device()
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment