Unverified Commit 3a6e0418 authored by HAI, committed by GitHub

[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

parent 2fa5cec7
......@@ -13,6 +13,7 @@ limitations under the License.
"""Fused operators for activation layers."""
import logging
from typing import Optional
import torch
......@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class SiluAndMul(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
......@@ -135,3 +140,10 @@ def get_act_fn(
act_fn, intermediate_size, input_is_parallel, params_dtype
)
return act_fn
if is_hip():
logger.info(
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
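With FlashInfer unavailable on ROCm, the activation layers re-export vLLM's PyTorch implementations. As a reference for what that fallback computes, here is a minimal standalone sketch of the SiluAndMul math (split the last dimension in half, apply SiLU to the gate half, multiply by the up half); the function name and shapes are illustrative, not copied from vLLM.

```python
import torch
import torch.nn.functional as F


def silu_and_mul_native(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate, up] along the last dimension.
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    return F.silu(gate) * up


x = torch.randn(4, 2 * 128)      # [tokens, 2 * intermediate_size]
out = silu_and_mul_native(x)     # [tokens, intermediate_size]
assert out.shape == (4, 128)
```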
......@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING
import torch
import torch.nn as nn
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from sglang.global_config import global_config
from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.utils import is_hip
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
# ROCm: FlashInfer is not available yet, so import it only on non-HIP builds
if not is_hip():
from flashinfer import (
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
class AttentionBackend(ABC):
"""The base class of attention backends"""
......
......@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.utils import is_hip
logger = init_logger(__name__)
......@@ -381,6 +383,7 @@ from torch.nn import Module
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d,
normalize_e4m3fn_to_e4m3fnuz,
per_tensor_dequantize,
)
from vllm.utils import print_warning_once
......@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
def process_weights_after_loading(self, layer: Module) -> None:
# If checkpoint is fp16, quantize in place.
# If checkpoint is fp16 or bfloat16, quantize in place.
if not self.quant_config.is_checkpoint_fp8_serialized:
w13_weight = torch.empty_like(
layer.w13_weight.data, dtype=torch.float8_e4m3fn
)
w2_weight = torch.empty_like(
layer.w2_weight.data, dtype=torch.float8_e4m3fn
)
# On ROCm (MI300X hardware), use float8_e4m3fnuz instead of float8_e4m3fn
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
......@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.a2_scale.max(), requires_grad=False
)
# If ROCm, normalize the weights and scales to e4m3fnuz
if is_hip():
# Normalize the weights and scales
w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
layer.w13_weight, layer.w13_scale, layer.a13_scale
)
w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
layer.w2_weight, layer.w2_scale, layer.a2_scale
)
# Reset the parameters
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
if a13_scale is not None:
layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
if a2_scale is not None:
layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
assert layer.w13_scale is not None
......
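The ROCm branch above converts e4m3fn checkpoints to e4m3fnuz because MI300-class hardware implements the fnuz variant, whose exponent bias is 8 rather than 7: an identical bit pattern decodes to half the e4m3fn value, and the 0x80 pattern (-0.0 in e4m3fn) decodes as NaN. The sketch below shows the idea behind vLLM's normalize_e4m3fn_to_e4m3fnuz under those assumptions; it is illustrative, not the vLLM implementation.

```python
from typing import Optional, Tuple

import torch


def normalize_e4m3fn_to_e4m3fnuz_sketch(
    weight: torch.Tensor,            # fp8 e4m3fn tensor
    weight_scale: torch.Tensor,      # per-tensor weight scale
    input_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    assert weight.dtype == torch.float8_e4m3fn
    bits = weight.view(torch.uint8)
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so remap it to +0.0.
    bits = torch.where(bits == 0x80, torch.zeros_like(bits), bits)
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # The same bits now mean half the value, so double the scales to compensate.
    new_weight_scale = weight_scale * 2.0
    new_input_scale = input_scale * 2.0 if input_scale is not None else None
    return weight_fnuz, new_weight_scale, new_input_scale


w = torch.randn(16, 16).clamp(-1, 1).to(torch.float8_e4m3fn)
w_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz_sketch(w, torch.tensor(1.0))
```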
......@@ -15,6 +15,7 @@ limitations under the License.
"""Fused operators for normalization layers."""
import logging
from typing import Optional, Tuple, Union
import torch
......@@ -27,6 +28,10 @@ from flashinfer.norm import (
)
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
class RMSNorm(CustomOp):
def __init__(
......@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
return x, residual
out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
if is_hip():
logger.info(
"FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
)
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
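As with the activation layers, RMSNorm falls back to a plain PyTorch path on ROCm when FlashInfer's fused rmsnorm kernels are unavailable. A minimal sketch of that native computation (the function name and shapes are illustrative):

```python
import torch


def rms_norm_native(
    x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight).to(orig_dtype)


hidden = torch.randn(4, 4096, dtype=torch.float16)
weight = torch.ones(4096, dtype=torch.float16)
out = rms_norm_native(hidden, weight)
```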
......@@ -2,17 +2,21 @@ import logging
from typing import Union
import torch
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_hip
# ROCm: FlashInfer is not available yet, so import it only on non-HIP builds
if not is_hip():
from flashinfer.sampling import (
min_p_sampling_from_probs,
top_k_renorm_prob,
top_k_top_p_sampling_from_probs,
top_p_renorm_prob,
)
logger = logging.getLogger(__name__)
......
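With flashinfer.sampling guarded out on ROCm, sampling runs through the "pytorch" backend. The sketch below shows what a pure-PyTorch top-k/top-p sampling step looks like; it is illustrative and not SGLang's exact implementation.

```python
import torch


def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: [batch, vocab], already softmax-normalized.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    # Top-k: zero out everything beyond the k most likely tokens.
    sorted_probs[:, top_k:] = 0.0
    # Top-p: zero out tokens once the preceding cumulative mass exceeds top_p.
    cumsum = sorted_probs.cumsum(dim=-1)
    sorted_probs[cumsum - sorted_probs > top_p] = 0.0
    # Renormalize and sample one token per row.
    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


probs = torch.softmax(torch.randn(2, 32000), dim=-1)
token_ids = top_k_top_p_sample(probs, top_k=50, top_p=0.9)
```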
......@@ -21,12 +21,15 @@ import re
from dataclasses import dataclass
import torch
from flashinfer import SegmentGEMMWrapper
from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
from sglang.srt.lora.lora_config import LoRAConfig
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.utils import replace_submodule
from sglang.srt.utils import is_hip, replace_submodule
# ROCm: FlashInfer is not available yet, so import it only on non-HIP builds
if not is_hip():
from flashinfer import SegmentGEMMWrapper
def get_stacked_name(name):
......
......@@ -19,7 +19,6 @@ limitations under the License.
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from flashinfer import bmm_fp8
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
......@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip
# ROCm: FlashInfer is not available yet, so import it only on non-HIP builds
if not is_hip():
from flashinfer import bmm_fp8
class DeepseekV2MLP(nn.Module):
......
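DeepSeek-V2 now imports flashinfer's bmm_fp8 only on non-HIP builds. Conceptually, an FP8 batched matmul can be emulated by dequantizing with the per-tensor scales and calling torch.bmm, as in the hedged sketch below; names and scale handling are illustrative, not the model's actual ROCm code path.

```python
import torch


def bmm_fp8_fallback(
    a_fp8: torch.Tensor,        # [B, M, K], float8_e4m3fn
    b_fp8: torch.Tensor,        # [B, K, N], float8_e4m3fn
    a_scale: torch.Tensor,      # per-tensor scale for a
    b_scale: torch.Tensor,      # per-tensor scale for b
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    # Dequantize to fp32, run a regular batched matmul, cast to the output dtype.
    a = a_fp8.to(torch.float32) * a_scale
    b = b_fp8.to(torch.float32) * b_scale
    return torch.bmm(a, b).to(out_dtype)


a = (torch.randn(2, 4, 8) / 8).to(torch.float8_e4m3fn)
b = (torch.randn(2, 8, 3) / 8).to(torch.float8_e4m3fn)
out = bmm_fp8_fallback(a, b, torch.tensor(8.0), torch.tensor(8.0))
```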
......@@ -19,7 +19,6 @@ import math
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from flashinfer import bmm_fp8
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
......@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.utils import is_hip
# ROCm: FlashInfer is not available yet, so import it only on non-HIP builds
if not is_hip():
from flashinfer import bmm_fp8
class MiniCPM3MLP(nn.Module):
......
......@@ -78,6 +78,7 @@ from sglang.srt.utils import (
assert_pkg_version,
configure_logger,
enable_show_time_cost,
is_hip,
kill_child_process,
maybe_set_triton_cache_manager,
prepare_model,
......@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
"at https://docs.flashinfer.ai/installation.html.",
)
if is_hip():
# TODO: find a better alternative; avoid fork on ROCm for now by forcing spawn
mp.set_start_method("spawn", force=True)
def _wait_and_warmup(server_args, pipe_finish_writer, pid):
headers = {}
......
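The server forces the "spawn" multiprocessing start method on ROCm because GPU runtime state generally does not survive fork. A standalone sketch of that pattern (the worker function and process count are illustrative):

```python
import multiprocessing as mp

import torch


def is_hip() -> bool:
    return torch.version.hip is not None


def _worker(rank: int) -> None:
    print(f"worker {rank} started with start method {mp.get_start_method()}")


if __name__ == "__main__":
    if is_hip():
        # force=True overrides a start method set earlier in the process.
        mp.set_start_method("spawn", force=True)
    procs = [mp.Process(target=_worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```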
......@@ -21,6 +21,8 @@ import logging
import random
from typing import List, Optional, Union
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
......@@ -164,6 +166,11 @@ class ServerArgs:
)
self.sampling_backend = "pytorch"
# ROCm: FlashInfer is not available yet, so fall back to the Triton attention backend and the PyTorch sampling backend
if is_hip():
self.attention_backend = "triton"
self.sampling_backend = "pytorch"
# Default kernel backends
if self.enable_mla:
logger.info("MLA optimization is tunred on. Use triton backend.")
......
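The ServerArgs hunk replaces the FlashInfer-based defaults with the Triton attention backend and the PyTorch sampling backend on ROCm. A simplified standalone sketch of that defaulting logic (the dataclass and its default values here are assumptions for illustration):

```python
from dataclasses import dataclass

import torch


def is_hip() -> bool:
    return torch.version.hip is not None


@dataclass
class BackendDefaults:
    attention_backend: str = "flashinfer"
    sampling_backend: str = "flashinfer"

    def apply_platform_defaults(self) -> None:
        # On ROCm, FlashInfer is unavailable, so switch to the supported backends.
        if is_hip():
            self.attention_backend = "triton"
            self.sampling_backend = "pytorch"


args = BackendDefaults()
args.apply_platform_defaults()
print(args.attention_backend, args.sampling_backend)
```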
......@@ -51,6 +51,11 @@ show_time_cost = False
time_infos = {}
# Detect AMD GPUs via the torch ROCm/HIP build flag
def is_hip() -> bool:
return torch.version.hip is not None
def enable_show_time_cost():
global show_time_cost
show_time_cost = True
......
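The new is_hip() helper simply checks the torch build flag. A small usage example mirroring the Fp8MoEMethod hunk above, picking the platform-matching FP8 dtype:

```python
import torch

from sglang.srt.utils import is_hip

# float8_e4m3fnuz on ROCm (MI300X), float8_e4m3fn elsewhere.
fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
print(f"is_hip={is_hip()}, fp8 dtype={fp8_dtype}")
```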