Unverified Commit d1da58e2 authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

unify is_cuda and is_hip (#4321)

parent 1cf63485
import torch
from torch import nn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip
from sglang.srt.utils import is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
class CustomOp(nn.Module):
......@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
def dispatch_forward(self):
if _is_cuda:
return self.forward_cuda
elif _is_rocm:
elif _is_hip:
return self.forward_hip
else:
return self.forward_native
......@@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
logger = logging.getLogger(__name__)
is_hip_ = is_hip()
_is_cuda = is_cuda()
_is_hip = is_hip()
if is_cuda():
if _is_cuda:
try:
import pynvml
except ImportError as e:
logger.warning("Failed to import pynvml with %r", e)
if is_hip_:
if _is_hip:
try:
from amdsmi import (
AmdSmiException,
......@@ -43,7 +44,7 @@ if is_hip_:
logger.warning("Failed to import amdsmi with %r", e)
try:
if ops.use_vllm_custom_allreduce and not is_hip_:
if ops.use_vllm_custom_allreduce and not _is_hip:
# Use vLLM custom allreduce
ops.meta_size()
else:
......@@ -63,7 +64,7 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
if is_hip_:
if _is_hip:
try:
amdsmi_init()
return fn(*args, **kwargs)
......@@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
if is_hip_:
if _is_hip:
"""
query if the set of gpus are fully connected by xgmi (1 hop)
"""
......@@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:
_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
_MAX_CAR_SIZE = 8192 * 1024
if is_hip_:
if _is_hip:
# crossover is at 16MB buffer size for ROCm
_MAX_CAR_SIZE = 2 * 8192 * 1024
......@@ -229,7 +230,7 @@ class CustomAllreduce:
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
if is_cuda() or is_hip_:
if _is_cuda or _is_hip:
full_nvlink = is_full_nvlink(physical_device_ids, world_size)
if world_size > 2 and not full_nvlink:
......@@ -243,7 +244,7 @@ class CustomAllreduce:
# this is expensive to compute at the first time
# then we cache the result
# On AMD GPU, p2p is always enabled between XGMI connected GPUs
if not is_hip_ and not _can_p2p(rank, world_size):
if not _is_hip and not _can_p2p(rank, world_size):
logger.warning(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
......@@ -256,7 +257,7 @@ class CustomAllreduce:
self.world_size = world_size
self.full_nvlink = full_nvlink
if ops.use_vllm_custom_allreduce and not is_hip_:
if ops.use_vllm_custom_allreduce and not _is_hip:
# Buffers memory are owned by this Python class and passed to C++.
# Meta data composes of two parts: meta data for synchronization and a
# temporary buffer for storing intermediate allreduce results.
......@@ -279,7 +280,7 @@ class CustomAllreduce:
)
ops.register_buffer(self._ptr, self.buffer_ptrs)
else:
if is_hip_:
if _is_hip:
# meta data buffers need to be "uncached" for signal on MI200
self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
self.buffer = torch.empty(
......@@ -418,7 +419,7 @@ class CustomAllreduce:
ops.register_buffer(self._ptr, inp, handles, offsets)
def register_graph_buffers(self):
if is_hip_:
if _is_hip:
handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
logger.info("Registering %d cuda graph addresses", len(offset))
......@@ -454,12 +455,12 @@ class CustomAllreduce:
return False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if ops.use_vllm_custom_allreduce and not is_hip_:
if ops.use_vllm_custom_allreduce and not _is_hip:
if self.world_size == 2 or self.full_nvlink:
return inp_size < self.max_size
return False
if is_hip_:
if _is_hip:
if self.full_nvlink:
if self.world_size == 8:
if self.MSCCL:
......@@ -532,7 +533,7 @@ class CustomAllreduce:
return None
if self._IS_CAPTURING:
if torch.cuda.is_current_stream_capturing():
if is_hip_:
if _is_hip:
return self.all_reduce_reg(input)
else:
return self.all_reduce(input, registered=True)
......@@ -541,7 +542,7 @@ class CustomAllreduce:
# allreduce is out-of-place.
return torch.empty_like(input)
else:
if is_hip_:
if _is_hip:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
......@@ -556,7 +557,7 @@ class CustomAllreduce:
if ops.use_vllm_custom_allreduce:
self.free_shared_buffer(self.meta_ptrs)
self.free_shared_buffer(self.buffer_ptrs)
elif is_cuda():
elif _is_cuda:
self.free_shared_buffer(self.buffer_ptrs)
self.free_shared_buffer(self.tmp_result_buffer_ptrs)
self.free_shared_buffer(self.barrier_in_ptrs)
......
......@@ -27,7 +27,7 @@ import triton.language as tl
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
_is_hip = is_hip()
logger = logging.getLogger(__name__)
......@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
):
BLOCK = 64
# [TODO] work around SGPR limit on MI3xx
if is_hip_:
if _is_hip:
BLOCK = 8
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
......@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
num_warps = 4
else:
num_warps = 2
if is_hip_:
if _is_hip:
num_warps = 1
BLOCK_DMODEL = triton.next_power_of_2(Lk)
......@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
Lv = v_buffer.shape[-1]
# [TODO] work around shmem limit on MI3xx
if is_hip_ and Lk >= 576:
if _is_hip and Lk >= 576:
BLOCK = 16
if Lk == 576:
......@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
extra_kargs = {}
num_stages = 2
if is_hip_:
if _is_hip:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
......@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
if is_hip_:
if _is_hip:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
......
......@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
CUDA_CAPABILITY = torch.cuda.get_device_capability()
is_hip_ = is_hip()
_is_hip = is_hip()
if global_server_args_dict.get("attention_reduce_in_fp32", False):
REDUCE_TRITON_TYPE = tl.float32
......@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
if is_hip_:
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
num_warps = 4
......@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
num_stages = 1
extra_kargs = {}
if is_hip_:
if _is_hip:
extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
_fwd_kernel[grid](
......
......@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
CUDA_CAPABILITY = torch.cuda.get_device_capability()
is_hip_ = is_hip()
_is_hip = is_hip()
@triton.jit
......@@ -330,7 +330,7 @@ def extend_attention_fwd(
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
if is_hip_:
if _is_hip:
BLOCK_M, BLOCK_N = (64, 64)
num_warps = 4
......@@ -364,7 +364,7 @@ def extend_attention_fwd(
num_stages = 1
extra_kargs = {}
if is_hip_:
if _is_hip:
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
_fwd_kernel[grid](
......@@ -403,7 +403,7 @@ def extend_attention_fwd(
Lv=Lv,
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
STORE_TRANSPOSE=is_hip_,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
num_stages=num_stages,
**extra_kargs,
......
......@@ -32,7 +32,7 @@ def is_hip():
return triton.runtime.driver.active.get_current_target().backend == "hip"
is_hip_ = is_hip()
_is_hip = is_hip()
@triton.jit
......@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
BLOCK = 32
# # [TODO] work around shmem limit on MI3xx
# if is_hip_ and kv_lora_rank >= 576:
# if _is_hip and kv_lora_rank >= 576:
# BLOCK = 16
qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
......@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
extra_kargs = {}
num_stages = 2
if is_hip_:
if _is_hip:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
......
......@@ -6,8 +6,9 @@ import triton
import triton.language as tl
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
from sglang.srt.utils import is_cuda
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_cuda = is_cuda()
if _is_cuda:
from sglang.srt.layers.quantization.fp8_kernel import (
sglang_per_token_group_quant_fp8,
......
......@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
logger = logging.getLogger(__name__)
_is_hip = is_hip()
class GroupedGemmRunner(torch.nn.Module):
flashinfer_gemm_warpper = None
......@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
# If checkpoint is fp16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
......
......@@ -23,10 +23,11 @@ from sglang.srt.utils import (
direct_register_custom_op,
get_bool_env_var,
get_device_name,
is_cuda,
is_hip,
)
is_hip_ = is_hip()
_is_hip = is_hip()
logger = logging.getLogger(__name__)
......@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_rocm = torch.cuda.is_available() and torch.version.hip
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gelu_and_mul, silu_and_mul
......@@ -46,7 +46,7 @@ if _is_cuda:
sglang_per_token_group_quant_fp8,
)
if _is_cuda or _is_rocm:
if _is_cuda or _is_hip:
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
......@@ -679,7 +679,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2 if is_hip_ else 4,
"num_stages": 2 if _is_hip else 4,
}
if M <= E:
config = {
......@@ -688,7 +688,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 4,
"num_stages": 2 if _is_hip else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
......@@ -698,7 +698,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2 if is_hip_ else 3,
"num_stages": 2 if _is_hip else 3,
}
else:
config = {
......@@ -976,7 +976,7 @@ def fused_experts_impl(
if (
not (use_fp8_w8a8 or use_int8_w8a8)
or block_shape is not None
or (is_hip_ and get_bool_env_var("CK_MOE"))
or (_is_hip and get_bool_env_var("CK_MOE"))
):
padded_size = 0
......@@ -1131,7 +1131,7 @@ def fused_experts_impl(
if no_combine:
pass
elif is_hip_:
elif _is_hip:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
......
......@@ -27,9 +27,9 @@ else:
import logging
is_hip_ = is_hip()
_is_hip = is_hip()
if is_hip_:
if _is_hip:
from aiter import ck_moe
logger = logging.getLogger(__name__)
......@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
......@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
correction_bias=correction_bias,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
assert not no_combine, "unsupported"
return ck_moe(
x,
......@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
# Case input scale: input_scale loading is only supported for fp8
if "input_scale" in weight_name:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
# this is needed for compressed-tensors only
......@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
quant_method = getattr(param, "quant_method", None)
if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 0.5
self._load_per_channel_weight_scale(
......@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
)
elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
# INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
loaded_weight = loaded_weight * 2.0
self._load_per_tensor_weight_scale(
......
......@@ -54,9 +54,9 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]
is_hip_ = is_hip()
_is_hip = is_hip()
if is_hip_:
if _is_hip:
from aiter.fused_moe_bf16_asm import asm_moe
from aiter.ops.shuffle import shuffle_weight
......@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
# kernel for fast weight-only FP8 quantization
self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
# Disable marlin for ROCm
if is_hip_:
if _is_hip:
self.use_marlin = False
self.block_quant = self.quant_config.weight_block_size is not None
......@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# activation_scheme: dynamic
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.weight,
......@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
weight = layer.weight
weight_scale = layer.weight_scale
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
weight=weight,
weight_scale=weight_scale,
......@@ -563,7 +563,7 @@ class Fp8MoEMethod:
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if (
is_hip_
_is_hip
): # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
# ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
w13_weight_scale1 = torch.nn.Parameter(
......@@ -630,7 +630,7 @@ class Fp8MoEMethod:
# Block quant doesn't need to process weights after loading
if self.block_quant:
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# activation_scheme: dynamic
w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=layer.w13_weight,
......@@ -667,7 +667,7 @@ class Fp8MoEMethod:
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
# If ROCm, use float8_e4m3fnuz instead (MI300x HW)
fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
......@@ -689,7 +689,7 @@ class Fp8MoEMethod:
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
if is_hip_:
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
......@@ -721,7 +721,7 @@ class Fp8MoEMethod:
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip_:
if _is_hip:
# Normalize the weights and scales
w13_weight, w13_weight_scale, w13_input_scale = (
normalize_e4m3fn_to_e4m3fnuz(
......@@ -771,7 +771,7 @@ class Fp8MoEMethod:
max_w13_scales, requires_grad=False
)
if is_hip_:
if _is_hip:
self.process_weights_hip_scale_padding(layer)
return
......@@ -882,7 +882,7 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
# TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
assert not no_combine, f"{no_combine=} is not supported."
return asm_moe(
......@@ -895,7 +895,7 @@ class Fp8MoEMethod:
layer.w2_weight_scale1,
activation=activation,
)
if is_hip_ and get_bool_env_var("CK_MOE"):
if _is_hip and get_bool_env_var("CK_MOE"):
# TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
assert (
activation == "silu"
......
......@@ -22,12 +22,12 @@ import torch
import triton
import triton.language as tl
from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
is_hip_ = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
_is_hip = is_hip()
fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
_is_cuda = torch.cuda.is_available() and torch.version.cuda
_is_cuda = is_cuda()
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
......@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
......@@ -332,7 +332,7 @@ def static_quant_fp8(
finfo = torch.finfo(dtype)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
fp8_min = -fp8_max
......@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
else:
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
if (_is_hip == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)
......
......@@ -17,8 +17,8 @@ from sglang.srt.utils import (
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
is_hip_ = is_hip()
if is_hip_ and get_bool_env_var("CK_MOE"):
_is_hip = is_hip()
if _is_hip and get_bool_env_var("CK_MOE"):
from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
......@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
output = fp8_blockwise_scaled_mm(
q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
)
elif is_hip_ and get_bool_env_var("CK_MOE"):
elif _is_hip and get_bool_env_var("CK_MOE"):
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
......@@ -142,7 +142,7 @@ def input_to_float8(
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
fp8_max = finfo.max
if is_hip_:
if _is_hip:
fp8_max = 224.0
scale = fp8_max / amax
x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
......
......@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import is_hip
_is_hip = is_hip()
class W8A8Fp8Config(QuantizationConfig):
"""Config class for W8A8 FP8 Quantization.
......@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
weight = layer.weight
weight_scale = layer.weight_scale.detach()
if is_hip():
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=weight, weight_scale=weight_scale
)
......
......@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.utils import is_hip
is_hip_ = is_hip()
_is_hip = is_hip()
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
......@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
else:
capture_bs = list(range(1, 33))
if is_hip_:
if _is_hip:
capture_bs += [i * 8 for i in range(21, 33)]
if max(capture_bs) > model_runner.req_to_token_pool.size:
......
......@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_hip
is_hip_ = is_hip()
_is_hip = is_hip()
class DeepseekModelNextN(nn.Module):
......@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if is_hip_:
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
......@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if is_hip_:
if _is_hip:
self_attn.w_scale *= 2.0
......
......@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
is_hip_ = is_hip()
_is_hip = is_hip()
if is_cuda_available():
from sgl_kernel import bmm_fp8
......@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
if no_absorb():
return self.forward_normal(positions, hidden_states, forward_batch)
else:
if is_hip_:
if _is_hip:
if (
os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
and forward_batch.forward_mode.is_decode()
......@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
weight_block_size = self.quant_config.weight_block_size
if weight_block_size is not None:
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
if is_hip_:
if _is_hip:
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
weight=w,
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
......@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
and self_attn.w_scale is None
):
self_attn.w_scale = self_attn.kv_b_proj.weight_scale
if is_hip_:
if _is_hip:
self_attn.w_scale *= 2.0
def get_embed_and_head(self):
......
......@@ -72,13 +72,17 @@ show_time_cost = False
time_infos = {}
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
"""Return whether it is HIP on the AMD ROCm platform."""
return torch.version.hip is not None
def is_rocm() -> bool:
return torch.cuda.is_available() and torch.version.hip
def is_cuda():
return hasattr(torch, "cuda") and torch.version.cuda is not None
return torch.cuda.is_available() and torch.version.cuda
def is_cuda_alike():
......@@ -100,11 +104,11 @@ def is_flashinfer_available():
"""
if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
return False
return torch.cuda.is_available() and torch.version.cuda
return is_cuda()
def is_cuda_available():
return torch.cuda.is_available() and torch.version.cuda
return is_cuda()
def enable_show_time_cost():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment