"examples/vscode:/vscode.git/clone" did not exist on "209145a43e127b55c59152438eb421f16689166e"
Unverified commit c555d794, authored by Zhaoyi Li and committed by GitHub
Browse files

Minor update for ROCm variable style (#5562)

parent e2574ee9
......@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
)
from sglang.srt.utils import is_hip
_is_hip_ = is_hip()
_is_hip = is_hip()
class BenchmarkConfig(TypedDict):
......@@ -112,8 +112,8 @@ def benchmark_config(
)
if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn)
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
......@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
# TODO(woosuk): Increase the search space and use a performance model to
# prune the search space.
configs: List[BenchmarkConfig] = []
if _is_hip_:
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
......
......@@ -33,7 +33,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip()
_is_hip = is_hip()
DTYPE_MAP = {
"float32": torch.float32,
......@@ -99,7 +99,7 @@ def w8a8_block_matmul(
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
else:
......@@ -157,7 +157,7 @@ def get_rocm_configs_compute_bound():
def get_configs_compute_bound():
configs = []
if is_hip_:
if _is_hip:
configs = get_rocm_configs_compute_bound()
else:
for num_stages in [2, 3, 4, 5]:
......@@ -244,7 +244,7 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
if input_type == "fp8":
fp8_info = torch.finfo(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
fp8_max, fp8_min = fp8_info.max, fp8_info.min
......@@ -252,14 +252,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
)
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
)
else:
int8_info = torch.iinfo(torch.int8)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment