Unverified commit d1da58e2, authored by Yineng Zhang, committed by GitHub

unify is_cuda and is_hip (#4321)

parent 1cf63485
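
For context, here is a minimal sketch of the unified helper pattern this commit adopts, assembled from the `sglang.srt.utils` hunk at the bottom of the diff. It assumes only that `torch` is installed and is an illustration of the pattern, not the full module:

```python
import torch

def is_hip() -> bool:
    """Return whether it is HIP on the AMD ROCm platform."""
    return torch.version.hip is not None

def is_cuda():
    # torch.version.cuda is a version string on CUDA builds and None otherwise,
    # so the expression is truthy only when a CUDA runtime is actually usable.
    return torch.cuda.is_available() and torch.version.cuda

# Each module evaluates the check once at import time and reuses the cached
# flag, replacing the older ad-hoc names such as `is_hip_` and `_is_rocm`.
_is_cuda = is_cuda()
_is_hip = is_hip()
```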
import torch
from torch import nn
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+from sglang.srt.utils import is_cuda, is_hip
+_is_cuda = is_cuda()
+_is_hip = is_hip()
class CustomOp(nn.Module):
@@ -34,7 +36,7 @@ class CustomOp(nn.Module):
    def dispatch_forward(self):
        if _is_cuda:
            return self.forward_cuda
-        elif _is_rocm:
+        elif _is_hip:
            return self.forward_hip
        else:
            return self.forward_native
@@ -22,15 +22,16 @@ from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip
logger = logging.getLogger(__name__)
-is_hip_ = is_hip()
+_is_cuda = is_cuda()
+_is_hip = is_hip()
-if is_cuda():
+if _is_cuda:
    try:
        import pynvml
    except ImportError as e:
        logger.warning("Failed to import pynvml with %r", e)
-if is_hip_:
+if _is_hip:
    try:
        from amdsmi import (
            AmdSmiException,
@@ -43,7 +44,7 @@ if is_hip_:
        logger.warning("Failed to import amdsmi with %r", e)
try:
-    if ops.use_vllm_custom_allreduce and not is_hip_:
+    if ops.use_vllm_custom_allreduce and not _is_hip:
        # Use vLLM custom allreduce
        ops.meta_size()
    else:
@@ -63,7 +64,7 @@ _R = TypeVar("_R")
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
    @wraps(fn)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        if is_hip_:
+        if _is_hip:
            try:
                amdsmi_init()
                return fn(*args, **kwargs)
@@ -81,7 +82,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@with_nvml_context
def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
-    if is_hip_:
+    if _is_hip:
        """
        query if the set of gpus are fully connected by xgmi (1 hop)
        """
@@ -145,7 +146,7 @@ def is_weak_contiguous(inp: torch.Tensor):
class CustomAllreduce:
    _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
    _MAX_CAR_SIZE = 8192 * 1024
-    if is_hip_:
+    if _is_hip:
        # crossover is at 16MB buffer size for ROCm
        _MAX_CAR_SIZE = 2 * 8192 * 1024
@@ -229,7 +230,7 @@ class CustomAllreduce:
        # test nvlink first, this will filter out most of the cases
        # where custom allreduce is not supported
        # this checks hardware and driver support for NVLink
-        if is_cuda() or is_hip_:
+        if _is_cuda or _is_hip:
            full_nvlink = is_full_nvlink(physical_device_ids, world_size)
        if world_size > 2 and not full_nvlink:
@@ -243,7 +244,7 @@ class CustomAllreduce:
        # this is expensive to compute at the first time
        # then we cache the result
        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
-        if not is_hip_ and not _can_p2p(rank, world_size):
+        if not _is_hip and not _can_p2p(rank, world_size):
            logger.warning(
                "Custom allreduce is disabled because your platform lacks "
                "GPU P2P capability or P2P test failed. To silence this "
@@ -256,7 +257,7 @@ class CustomAllreduce:
        self.world_size = world_size
        self.full_nvlink = full_nvlink
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
            # Buffers memory are owned by this Python class and passed to C++.
            # Meta data composes of two parts: meta data for synchronization and a
            # temporary buffer for storing intermediate allreduce results.
@@ -279,7 +280,7 @@ class CustomAllreduce:
            )
            ops.register_buffer(self._ptr, self.buffer_ptrs)
        else:
-            if is_hip_:
+            if _is_hip:
                # meta data buffers need to be "uncached" for signal on MI200
                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
                self.buffer = torch.empty(
@@ -418,7 +419,7 @@ class CustomAllreduce:
        ops.register_buffer(self._ptr, inp, handles, offsets)
    def register_graph_buffers(self):
-        if is_hip_:
+        if _is_hip:
            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
            logger.info("Registering %d cuda graph addresses", len(offset))
@@ -454,12 +455,12 @@ class CustomAllreduce:
            return False
        # for 4 or more non NVLink-capable GPUs, custom allreduce provides
        # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce and not is_hip_:
+        if ops.use_vllm_custom_allreduce and not _is_hip:
            if self.world_size == 2 or self.full_nvlink:
                return inp_size < self.max_size
            return False
-        if is_hip_:
+        if _is_hip:
            if self.full_nvlink:
                if self.world_size == 8:
                    if self.MSCCL:
@@ -532,7 +533,7 @@ class CustomAllreduce:
            return None
        if self._IS_CAPTURING:
            if torch.cuda.is_current_stream_capturing():
-                if is_hip_:
+                if _is_hip:
                    return self.all_reduce_reg(input)
                else:
                    return self.all_reduce(input, registered=True)
@@ -541,7 +542,7 @@ class CustomAllreduce:
                # allreduce is out-of-place.
                return torch.empty_like(input)
        else:
-            if is_hip_:
+            if _is_hip:
                # note: outside of cuda graph context,
                # custom allreduce incurs a cost of cudaMemcpy, which should
                # be small(<=1% of overall latency) compared to the performance
@@ -556,7 +557,7 @@ class CustomAllreduce:
        if ops.use_vllm_custom_allreduce:
            self.free_shared_buffer(self.meta_ptrs)
            self.free_shared_buffer(self.buffer_ptrs)
-        elif is_cuda():
+        elif _is_cuda:
            self.free_shared_buffer(self.buffer_ptrs)
            self.free_shared_buffer(self.tmp_result_buffer_ptrs)
            self.free_shared_buffer(self.barrier_in_ptrs)
...
@@ -27,7 +27,7 @@ import triton.language as tl
from sglang.srt.utils import is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
logger = logging.getLogger(__name__)
@@ -180,7 +180,7 @@ def _decode_att_m_fwd(
):
    BLOCK = 64
    # [TODO] work around SGPR limit on MI3xx
-    if is_hip_:
+    if _is_hip:
        BLOCK = 8
    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
@@ -195,7 +195,7 @@ def _decode_att_m_fwd(
        num_warps = 4
    else:
        num_warps = 2
-    if is_hip_:
+    if _is_hip:
        num_warps = 1
    BLOCK_DMODEL = triton.next_power_of_2(Lk)
@@ -406,7 +406,7 @@ def _decode_grouped_att_m_fwd(
    Lv = v_buffer.shape[-1]
    # [TODO] work around shmem limit on MI3xx
-    if is_hip_ and Lk >= 576:
+    if _is_hip and Lk >= 576:
        BLOCK = 16
    if Lk == 576:
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
    extra_kargs = {}
    num_stages = 2
-    if is_hip_:
+    if _is_hip:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
@@ -546,7 +546,7 @@ def _decode_softmax_reducev_fwd(
    NUM_KV_SPLITS = num_kv_splits
    extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
...
@@ -9,7 +9,7 @@ is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()
-is_hip_ = is_hip()
+_is_hip = is_hip()
if global_server_args_dict.get("attention_reduce_in_fp32", False):
    REDUCE_TRITON_TYPE = tl.float32
@@ -1032,7 +1032,7 @@ def extend_attention_fwd(
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)
-    if is_hip_:
+    if _is_hip:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4
@@ -1062,7 +1062,7 @@ def extend_attention_fwd(
        num_stages = 1
    extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
        extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2}
    _fwd_kernel[grid](
...
@@ -29,7 +29,7 @@ is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()
-is_hip_ = is_hip()
+_is_hip = is_hip()
@triton.jit
@@ -330,7 +330,7 @@ def extend_attention_fwd(
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)
-    if is_hip_:
+    if _is_hip:
        BLOCK_M, BLOCK_N = (64, 64)
        num_warps = 4
@@ -364,7 +364,7 @@ def extend_attention_fwd(
        num_stages = 1
    extra_kargs = {}
-    if is_hip_:
+    if _is_hip:
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
    _fwd_kernel[grid](
@@ -403,7 +403,7 @@ def extend_attention_fwd(
        Lv=Lv,
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        STORE_TRANSPOSE=is_hip_,
+        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
        num_stages=num_stages,
        **extra_kargs,
...
@@ -32,7 +32,7 @@ def is_hip():
    return triton.runtime.driver.active.get_current_target().backend == "hip"
-is_hip_ = is_hip()
+_is_hip = is_hip()
@triton.jit
@@ -333,7 +333,7 @@ def _decode_grouped_att_m_fwd_rope(
    BLOCK = 32
    # # [TODO] work around shmem limit on MI3xx
-    # if is_hip_ and kv_lora_rank >= 576:
+    # if _is_hip and kv_lora_rank >= 576:
    #     BLOCK = 16
    qk_rope_head_dim = k_buffer.shape[-1] - kv_lora_rank
@@ -353,7 +353,7 @@ def _decode_grouped_att_m_fwd_rope(
    extra_kargs = {}
    num_stages = 2
-    if is_hip_:
+    if _is_hip:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2}
...
@@ -6,8 +6,9 @@ import triton
import triton.language as tl
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.utils import is_cuda
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
if _is_cuda:
    from sglang.srt.layers.quantization.fp8_kernel import (
        sglang_per_token_group_quant_fp8,
...
@@ -30,6 +30,8 @@ from sglang.srt.utils import is_hip, set_weight_attrs
logger = logging.getLogger(__name__)
+_is_hip = is_hip()
class GroupedGemmRunner(torch.nn.Module):
    flashinfer_gemm_warpper = None
@@ -703,7 +705,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
        # If checkpoint is fp16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
...
@@ -23,10 +23,11 @@ from sglang.srt.utils import (
    direct_register_custom_op,
    get_bool_env_var,
    get_device_name,
+    is_cuda,
    is_hip,
)
-is_hip_ = is_hip()
+_is_hip = is_hip()
logger = logging.getLogger(__name__)
@@ -36,8 +37,7 @@ enable_moe_align_block_size_triton = bool(
    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
)
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import gelu_and_mul, silu_and_mul
@@ -46,7 +46,7 @@ if _is_cuda:
        sglang_per_token_group_quant_fp8,
    )
-if _is_cuda or _is_rocm:
+if _is_cuda or _is_hip:
    from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
@@ -679,7 +679,7 @@ def get_default_config(
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
-            "num_stages": 2 if is_hip_ else 4,
+            "num_stages": 2 if _is_hip else 4,
        }
        if M <= E:
            config = {
@@ -688,7 +688,7 @@ def get_default_config(
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 4,
+                "num_stages": 2 if _is_hip else 4,
            }
    else:
        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -698,7 +698,7 @@ def get_default_config(
                "BLOCK_SIZE_K": block_shape[1],
                "GROUP_SIZE_M": 32,
                "num_warps": 4,
-                "num_stages": 2 if is_hip_ else 3,
+                "num_stages": 2 if _is_hip else 3,
            }
        else:
            config = {
@@ -976,7 +976,7 @@ def fused_experts_impl(
    if (
        not (use_fp8_w8a8 or use_int8_w8a8)
        or block_shape is not None
-        or (is_hip_ and get_bool_env_var("CK_MOE"))
+        or (_is_hip and get_bool_env_var("CK_MOE"))
    ):
        padded_size = 0
@@ -1131,7 +1131,7 @@ def fused_experts_impl(
        if no_combine:
            pass
-        elif is_hip_:
+        elif _is_hip:
            ops.moe_sum(
                intermediate_cache3.view(*intermediate_cache3.shape),
                out_hidden_states[begin_chunk_idx:end_chunk_idx],
...
@@ -27,9 +27,9 @@ else:
import logging
-is_hip_ = is_hip()
+_is_hip = is_hip()
-if is_hip_:
+if _is_hip:
    from aiter import ck_moe
logger = logging.getLogger(__name__)
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        set_weight_attrs(w2_weight, extra_weight_attrs)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            layer.w13_weight = torch.nn.Parameter(
                permute_weight(layer.w13_weight.data),
                requires_grad=False,
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            correction_bias=correction_bias,
        )
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            assert not no_combine, "unsupported"
            return ck_moe(
                x,
@@ -514,7 +514,7 @@ class FusedMoE(torch.nn.Module):
        # Case input scale: input_scale loading is only supported for fp8
        if "input_scale" in weight_name:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
-            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0
            # this is needed for compressed-tensors only
@@ -556,7 +556,7 @@ class FusedMoE(torch.nn.Module):
        quant_method = getattr(param, "quant_method", None)
        if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
-            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 0.5
            self._load_per_channel_weight_scale(
@@ -579,7 +579,7 @@ class FusedMoE(torch.nn.Module):
            )
        elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
-            if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
                loaded_weight = loaded_weight * 2.0
            self._load_per_tensor_weight_scale(
...
@@ -54,9 +54,9 @@ from sglang.srt.utils import (
ACTIVATION_SCHEMES = ["static", "dynamic"]
-is_hip_ = is_hip()
+_is_hip = is_hip()
-if is_hip_:
+if _is_hip:
    from aiter.fused_moe_bf16_asm import asm_moe
    from aiter.ops.shuffle import shuffle_weight
@@ -175,7 +175,7 @@ class Fp8LinearMethod(LinearMethodBase):
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
        # Disable marlin for ROCm
-        if is_hip_:
+        if _is_hip:
            self.use_marlin = False
        self.block_quant = self.quant_config.weight_block_size is not None
@@ -287,7 +287,7 @@ class Fp8LinearMethod(LinearMethodBase):
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                # activation_scheme: dynamic
                weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
@@ -347,7 +347,7 @@ class Fp8LinearMethod(LinearMethodBase):
        weight = layer.weight
        weight_scale = layer.weight_scale
        # If ROCm, normalize the weights and scales to e4m3fnuz
-        if is_hip_:
+        if _is_hip:
            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight,
                weight_scale=weight_scale,
@@ -563,7 +563,7 @@ class Fp8MoEMethod:
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        if (
-            is_hip_
+            _is_hip
        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
@@ -630,7 +630,7 @@ class Fp8MoEMethod:
        # Block quant doesn't need to process weights after loading
        if self.block_quant:
            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                # activation_scheme: dynamic
                w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.w13_weight,
@@ -667,7 +667,7 @@ class Fp8MoEMethod:
        # If checkpoint is fp16 or bfloat16, quantize in place.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -689,7 +689,7 @@ class Fp8MoEMethod:
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            if is_hip_:
+            if _is_hip:
                self.process_weights_hip_scale_padding(layer)
            return
@@ -721,7 +721,7 @@ class Fp8MoEMethod:
            )
            # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip_:
+            if _is_hip:
                # Normalize the weights and scales
                w13_weight, w13_weight_scale, w13_input_scale = (
                    normalize_e4m3fn_to_e4m3fnuz(
@@ -771,7 +771,7 @@ class Fp8MoEMethod:
                max_w13_scales, requires_grad=False
            )
-            if is_hip_:
+            if _is_hip:
                self.process_weights_hip_scale_padding(layer)
            return
@@ -882,7 +882,7 @@ class Fp8MoEMethod:
            correction_bias=correction_bias,
        )
-        if is_hip_ and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
            assert not no_combine, f"{no_combine=} is not supported."
            return asm_moe(
@@ -895,7 +895,7 @@ class Fp8MoEMethod:
                layer.w2_weight_scale1,
                activation=activation,
            )
-        if is_hip_ and get_bool_env_var("CK_MOE"):
+        if _is_hip and get_bool_env_var("CK_MOE"):
            # TODO(CK_MOE): FP8 or FP8 block_quant only supports 'silu' for the time-being.
            assert (
                activation == "silu"
...
@@ -22,12 +22,12 @@ import torch
import triton
import triton.language as tl
-from sglang.srt.utils import get_device_core_count, get_device_name, is_hip
+from sglang.srt.utils import get_device_core_count, get_device_name, is_cuda, is_hip
-is_hip_ = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
+_is_hip = is_hip()
+fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-_is_cuda = torch.cuda.is_available() and torch.version.cuda
+_is_cuda = is_cuda()
if _is_cuda:
    import deep_gemm
    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
@@ -157,7 +157,7 @@ def per_token_group_quant_fp8(
    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
        fp8_max = 224.0
    fp8_min = -fp8_max
@@ -332,7 +332,7 @@ def static_quant_fp8(
    finfo = torch.finfo(dtype)
    fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
        fp8_max = 224.0
    fp8_min = -fp8_max
@@ -732,7 +732,7 @@ def w8a8_block_fp8_matmul(
    else:
        kernel = (
            _w8a8_block_fp8_matmul_unrolledx4
-            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            if (_is_hip == True and num_workgroups <= get_device_core_count())
            else _w8a8_block_fp8_matmul
        )
...
@@ -17,8 +17,8 @@ from sglang.srt.utils import (
use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_KERNEL")
-is_hip_ = is_hip()
+_is_hip = is_hip()
-if is_hip_ and get_bool_env_var("CK_MOE"):
+if _is_hip and get_bool_env_var("CK_MOE"):
    from aiter import gemm_a8w8_blockscale
_is_cuda = is_cuda()
@@ -111,7 +111,7 @@ def apply_w8a8_block_fp8_linear(
        output = fp8_blockwise_scaled_mm(
            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
        )
-    elif is_hip_ and get_bool_env_var("CK_MOE"):
+    elif _is_hip and get_bool_env_var("CK_MOE"):
        q_input, x_scale = per_token_group_quant_fp8(
            input_2d, block_size[1], column_major_scales=False
        )
@@ -142,7 +142,7 @@ def input_to_float8(
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    fp8_max = finfo.max
-    if is_hip_:
+    if _is_hip:
        fp8_max = 224.0
    scale = fp8_max / amax
    x_scl_sat = (x * scale).clamp(min=-fp8_max, max=fp8_max)
...
@@ -16,6 +16,8 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import is_hip
+_is_hip = is_hip()
class W8A8Fp8Config(QuantizationConfig):
    """Config class for W8A8 FP8 Quantization.
@@ -71,7 +73,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        weight = layer.weight
        weight_scale = layer.weight_scale.detach()
-        if is_hip():
+        if _is_hip:
            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                weight=weight, weight_scale=weight_scale
            )
...
@@ -35,7 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.utils import is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner
@@ -119,7 +119,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
    else:
        capture_bs = list(range(1, 33))
-    if is_hip_:
+    if _is_hip:
        capture_bs += [i * 8 for i in range(21, 33)]
    if max(capture_bs) > model_runner.req_to_token_pool.size:
...
@@ -40,7 +40,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
from sglang.srt.utils import add_prefix, is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
class DeepseekModelNextN(nn.Module):
@@ -277,7 +277,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if is_hip_:
+                if _is_hip:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -301,7 +301,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
                if _is_hip:
-                if is_hip_:
+                if _is_hip:
                    self_attn.w_scale *= 2.0
...
@@ -65,7 +65,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda_available, is_hip
-is_hip_ = is_hip()
+_is_hip = is_hip()
if is_cuda_available():
    from sgl_kernel import bmm_fp8
@@ -571,7 +571,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        if no_absorb():
            return self.forward_normal(positions, hidden_states, forward_batch)
        else:
-            if is_hip_:
+            if _is_hip:
                if (
                    os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
                    and forward_batch.forward_mode.is_decode()
@@ -1190,7 +1190,7 @@ class DeepseekV2ForCausalLM(nn.Module):
            weight_block_size = self.quant_config.weight_block_size
            if weight_block_size is not None:
                assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                if is_hip_:
+                if _is_hip:
                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                        weight=w,
                        weight_scale=self_attn.kv_b_proj.weight_scale_inv,
@@ -1230,7 +1230,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                and self_attn.w_scale is None
            ):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if is_hip_:
+                if _is_hip:
                    self_attn.w_scale *= 2.0
    def get_embed_and_head(self):
...
@@ -72,13 +72,17 @@ show_time_cost = False
time_infos = {}
+# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
def is_hip() -> bool:
+    """Return whether it is HIP on the AMD ROCm platform."""
    return torch.version.hip is not None
+def is_rocm() -> bool:
+    return torch.cuda.is_available() and torch.version.hip
def is_cuda():
-    return hasattr(torch, "cuda") and torch.version.cuda is not None
+    return torch.cuda.is_available() and torch.version.cuda
def is_cuda_alike():
@@ -100,11 +104,11 @@ def is_flashinfer_available():
    """
    if not get_bool_env_var("SGLANG_IS_FLASHINFER_AVAILABLE", default="true"):
        return False
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
def is_cuda_available():
-    return torch.cuda.is_available() and torch.version.cuda
+    return is_cuda()
def enable_show_time_cost():
...