"docs/source/vscode:/vscode.git/clone" did not exist on "657c220de7c515f9dfede94b574a56d5685d7238"
Unverified Commit 5a144a8a authored by kk, committed by GitHub

Fix runtime error on ROCm platform (#5147)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: root <root@dell300x-pla-t10-17.pla.dcgpu>
parent 27f8e6b9
@@ -4,6 +4,10 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
+
 fused_softcap_autotune = triton.autotune(
     configs=[
         triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
@@ -185,6 +189,9 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
     assert x.shape == residual.shape and x.dtype == residual.dtype
     output, mid = torch.empty_like(x), torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     if autotune:
         fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
             output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
@@ -193,7 +200,10 @@ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=Fal
         config = {
             "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
             "num_warps": max(
-                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+                min(
+                    triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps
+                ),
+                4,
             ),
         }
@@ -250,10 +260,13 @@ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
     else:
         output = torch.empty_like(x)
     bs, hidden_dim = x.shape
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
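The same per-platform warp cap is applied again in the MoE router file below. As a standalone illustration, the heuristic these hunks converge on can be written as a small helper; the helper name and the comments are ours, not part of the patch. The stated motivation is an assumption: on ROCm a Triton "warp" maps to a 64-lane wavefront, so 16 of them already fill a 1024-thread workgroup, while on CUDA the usual cap of 32 warps of 32 lanes fits the same limit.

import triton

from sglang.srt.utils import is_hip

_is_hip = is_hip()


def _num_warps_for(hidden_dim: int) -> int:
    # Hypothetical helper mirroring the heuristic in the patch: grow the warp
    # count with hidden_dim, but keep it between 4 and a per-platform cap.
    # Assumption: 16 wavefronts * 64 lanes = 1024 threads on ROCm, matching
    # 32 warps * 32 lanes = 1024 threads on CUDA.
    cap = 16 if _is_hip else 32  # named min_num_warps in the patch
    return max(min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), cap), 4)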
@@ -5,6 +5,9 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import fused_topk
+from sglang.srt.utils import is_hip
+
+_is_hip = is_hip()
 
 
 @triton.jit
@@ -116,10 +119,13 @@ def fused_moe_router_impl(
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
 
     grid = lambda meta: (bs,)
+
+    min_num_warps = 16 if _is_hip else 32
+
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), min_num_warps), 4
         ),
     }
@@ -171,6 +171,7 @@ def input_to_float8(
     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
     fp8_max = finfo.max
     if _is_hip:
+        dtype = torch.float8_e4m3fnuz
         fp8_max = 224.0
     scale = fp8_max / amax
     x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
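For context, a minimal per-tensor sketch of what this fp8 path does after the change follows. The function name and the simplified body are ours; only the platform branch (switching the cast dtype to torch.float8_e4m3fnuz and pinning fp8_max to 224.0 on ROCm) is taken from the patch, and the real helper also handles a caller-supplied dtype and a per-token min/max that are omitted here.

import torch


def input_to_float8_sketch(x: torch.Tensor, is_hip_platform: bool):
    # Default fp8 format and its largest representable value.
    dtype = torch.float8_e4m3fn
    fp8_max = torch.finfo(dtype).max
    if is_hip_platform:
        # ROCm uses the "fnuz" e4m3 encoding; the patch switches the cast dtype
        # accordingly and caps the scale at 224.0 to match the HIP kernels.
        dtype = torch.float8_e4m3fnuz
        fp8_max = 224.0
    amax = x.abs().amax().clamp(min=1e-12).float()
    scale = fp8_max / amax
    x_fp8 = (x * scale).clamp(min=-fp8_max, max=fp8_max).to(dtype)
    return x_fp8, scale.reciprocal()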