Unverified Commit c555d794 authored by Zhaoyi Li's avatar Zhaoyi Li Committed by GitHub
Browse files

Minor update for ROCm variable style (#5562)

parent e2574ee9
...@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( ...@@ -20,7 +20,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
) )
from sglang.srt.utils import is_hip from sglang.srt.utils import is_hip
_is_hip_ = is_hip() _is_hip = is_hip()
class BenchmarkConfig(TypedDict): class BenchmarkConfig(TypedDict):
...@@ -112,8 +112,8 @@ def benchmark_config( ...@@ -112,8 +112,8 @@ def benchmark_config(
) )
if use_fp8_w8a8: if use_fp8_w8a8:
w1 = w1.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) w1 = w1.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
w2 = w2.to(torch.float8_e4m3fnuz if _is_hip_ else torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32) input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
...@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]: ...@@ -204,7 +204,7 @@ def get_configs_compute_bound() -> List[Dict[str, int]]:
# TODO(woosuk): Increase the search space and use a performance model to # TODO(woosuk): Increase the search space and use a performance model to
# prune the search space. # prune the search space.
configs: List[BenchmarkConfig] = [] configs: List[BenchmarkConfig] = []
if _is_hip_: if _is_hip:
configs = get_rocm_configs_compute_bound() configs = get_rocm_configs_compute_bound()
else: else:
for num_stages in [2, 3, 4, 5]: for num_stages in [2, 3, 4, 5]:
......
...@@ -33,7 +33,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ...@@ -33,7 +33,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul from sglang.srt.layers.quantization.int8_kernel import _w8a8_block_int8_matmul
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
is_hip_ = is_hip() _is_hip = is_hip()
DTYPE_MAP = { DTYPE_MAP = {
"float32": torch.float32, "float32": torch.float32,
...@@ -99,7 +99,7 @@ def w8a8_block_matmul( ...@@ -99,7 +99,7 @@ def w8a8_block_matmul(
if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fnuz or A.dtype == torch.float8_e4m3fn:
kernel = ( kernel = (
_w8a8_block_fp8_matmul_unrolledx4 _w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count()) if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul else _w8a8_block_fp8_matmul
) )
else: else:
...@@ -157,7 +157,7 @@ def get_rocm_configs_compute_bound(): ...@@ -157,7 +157,7 @@ def get_rocm_configs_compute_bound():
def get_configs_compute_bound(): def get_configs_compute_bound():
configs = [] configs = []
if is_hip_: if _is_hip:
configs = get_rocm_configs_compute_bound() configs = get_rocm_configs_compute_bound()
else: else:
for num_stages in [2, 3, 4, 5]: for num_stages in [2, 3, 4, 5]:
...@@ -244,7 +244,7 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): ...@@ -244,7 +244,7 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
if input_type == "fp8": if input_type == "fp8":
fp8_info = torch.finfo( fp8_info = torch.finfo(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
) )
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
...@@ -252,14 +252,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): ...@@ -252,14 +252,14 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
) )
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to( A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
) )
B_fp32 = ( B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
) )
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to( B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(
torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
) )
else: else:
int8_info = torch.iinfo(torch.int8) int8_info = torch.iinfo(torch.int8)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment