"examples/pytorch/vscode:/vscode.git/clone" did not exist on "cb2327d4a93c042f3c6fe1c42fe8f4c31f087e3b"
Unverified Commit 3a6e0418 authored by HAI, committed by GitHub

[Feature, Hardware] Enable SGLang on AMD GPUs via PyTorch for ROCm (#1420)

parent 2fa5cec7
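One pattern runs through every hunk below: detect ROCm at runtime from the PyTorch build and, on AMD GPUs, skip the FlashInfer imports in favor of the Triton, PyTorch, and vLLM kernels. A minimal sketch of that guard, assuming only that a FlashInfer symbol such as BatchDecodeWithPagedKVCacheWrapper is importable on CUDA builds (illustrative, not the sglang module itself):

import torch


def is_hip() -> bool:
    # True when PyTorch was built against ROCm (AMD GPUs), False on CUDA builds.
    return torch.version.hip is not None


# Guard pattern used throughout the diff: FlashInfer ships CUDA-only kernels,
# so it is imported only when the build is not ROCm.
if not is_hip():
    from flashinfer import BatchDecodeWithPagedKVCacheWrapper  # noqa: F401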
@@ -13,6 +13,7 @@ limitations under the License.
 """Fused operators for activation layers."""

+import logging
 from typing import Optional

 import torch
@@ -28,6 +29,10 @@ from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+

 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -135,3 +140,10 @@ def get_act_fn(
             act_fn, intermediate_size, input_is_parallel, params_dtype
         )
     return act_fn
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.activation import GeluAndMul, SiluAndMul
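A note on the fallback at the bottom of the activation hunk above: re-importing GeluAndMul and SiluAndMul from vllm.model_executor.layers.activation rebinds the module-level names, so downstream code that imports them from the sglang module transparently gets the non-FlashInfer implementations on ROCm. For reference, SiluAndMul splits the packed gate/up projection and computes silu(gate) * up; a plain-PyTorch equivalent of that computation (a sketch of the semantics, not sglang's or vLLM's exact code):

import torch
import torch.nn.functional as F


def silu_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x packs [gate, up] along the last dimension; output is silu(gate) * up.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


x = torch.randn(2, 8)  # batch of 2, packed hidden size 2 * 4
assert silu_and_mul_reference(x).shape == (2, 4)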
@@ -12,22 +12,26 @@ from typing import TYPE_CHECKING

 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.cascade import merge_state
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

 from sglang.global_config import global_config
 from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.utils import is_hip

 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner

+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import (
+        BatchDecodeWithPagedKVCacheWrapper,
+        BatchPrefillWithPagedKVCacheWrapper,
+        BatchPrefillWithRaggedKVCacheWrapper,
+    )
+    from flashinfer.cascade import merge_state
+    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+

 class AttentionBackend(ABC):
     """The base class of attention backends"""
...
@@ -18,6 +18,8 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs

+from sglang.srt.utils import is_hip
+
 logger = init_logger(__name__)
@@ -381,6 +383,7 @@ from torch.nn import Module
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d,
+    normalize_e4m3fn_to_e4m3fnuz,
     per_tensor_dequantize,
 )
 from vllm.utils import print_warning_once
@@ -479,14 +482,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):

     def process_weights_after_loading(self, layer: Module) -> None:

-        # If checkpoint is fp16, quantize in place.
+        # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
-            w13_weight = torch.empty_like(
-                layer.w13_weight.data, dtype=torch.float8_e4m3fn
-            )
-            w2_weight = torch.empty_like(
-                layer.w2_weight.data, dtype=torch.float8_e4m3fn
-            )
+            # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
+            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)

             # Re-initialize w13_scale because we directly quantize
             # merged w13 weights and generate a single scaling factor.
@@ -534,6 +535,25 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 layer.a2_scale.max(), requires_grad=False
             )

+            # If ROCm, normalize the weights and scales to e4m3fnuz
+            if is_hip():
+                # Normalize the weights and scales
+                w13_weight, w13_scale, a13_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_scale, layer.a13_scale
+                )
+                w2_weight, w2_scale, a2_scale = normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w2_weight, layer.w2_scale, layer.a2_scale
+                )
+                # Reset the parameters
+                layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+                layer.w13_scale = torch.nn.Parameter(w13_scale, requires_grad=False)
+                if a13_scale is not None:
+                    layer.a13_scale = torch.nn.Parameter(a13_scale, requires_grad=False)
+                layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+                layer.w2_scale = torch.nn.Parameter(w2_scale, requires_grad=False)
+                if a2_scale is not None:
+                    layer.a2_scale = torch.nn.Parameter(a2_scale, requires_grad=False)
+
             # Fp8 moe kernel needs single weight scale for w13 per expert.
             # We take the max then dequant and requant each expert.
             assert layer.w13_scale is not None
...
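Why the extra normalization on ROCm: NVIDIA fp8 checkpoints are serialized as float8_e4m3fn, while MI300-class hardware uses float8_e4m3fnuz, which has a different exponent bias (8 instead of 7) and reserves the 0x80 bit pattern for NaN rather than negative zero. The same bits therefore decode to half the value, which is why the scales get doubled. A rough sketch of that conversion, assuming a PyTorch build that exposes torch.float8_e4m3fnuz (the idea behind normalize_e4m3fn_to_e4m3fnuz, not necessarily vLLM's exact implementation):

import torch


def e4m3fn_to_e4m3fnuz_sketch(weight: torch.Tensor, scale: torch.Tensor):
    # Reinterpret the fp8 payload bit-for-bit, fixing the one incompatible pattern:
    # 0x80 is -0.0 in e4m3fn but NaN in e4m3fnuz, so map it to +0.0 first.
    bits = weight.view(torch.int8).clone()
    bits[bits == -128] = 0
    weight_fnuz = bits.view(torch.float8_e4m3fnuz)
    # The same bits decode to half the value under e4m3fnuz, so double the scale
    # to keep the dequantized weights unchanged.
    return weight_fnuz, scale * 2.0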
@@ -15,6 +15,7 @@ limitations under the License.
 """Fused operators for normalization layers."""

+import logging
 from typing import Optional, Tuple, Union

 import torch
@@ -27,6 +28,10 @@ from flashinfer.norm import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.utils import is_hip
+
+logger = logging.getLogger(__name__)
+

 class RMSNorm(CustomOp):
     def __init__(
@@ -109,3 +114,10 @@ class GemmaRMSNorm(CustomOp):
             return x, residual
         out = gemma_rmsnorm(x, self.weight.data, self.variance_epsilon)
         return out
+
+
+if is_hip():
+    logger.info(
+        "FlashInfer is not available on AMD GPUs. Fallback to other kernel libraries."
+    )
+    from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@@ -2,17 +2,21 @@ import logging
 from typing import Union

 import torch
-from flashinfer.sampling import (
-    min_p_sampling_from_probs,
-    top_k_renorm_prob,
-    top_k_top_p_sampling_from_probs,
-    top_p_renorm_prob,
-)
 from torch import nn

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer.sampling import (
+        min_p_sampling_from_probs,
+        top_k_renorm_prob,
+        top_k_top_p_sampling_from_probs,
+        top_p_renorm_prob,
+    )

 logger = logging.getLogger(__name__)
...
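With flashinfer.sampling unavailable on ROCm, sampling goes through the "pytorch" backend (see the server_args hunk further down, which forces sampling_backend = "pytorch" on HIP). For intuition, a joint top-k/top-p sample in plain PyTorch looks roughly like this; an illustrative sketch, not sglang's implementation:

import torch


def sample_top_k_top_p(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # probs: [batch, vocab], already softmax-normalized.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    sorted_probs[:, top_k:] = 0.0                            # drop everything past top-k
    cumulative = sorted_probs.cumsum(dim=-1)
    sorted_probs[cumulative - sorted_probs > top_p] = 0.0    # drop the tail beyond top-p mass
    sorted_probs /= sorted_probs.sum(dim=-1, keepdim=True)   # renormalize
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)         # map back to vocab ids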
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass

 import torch
-from flashinfer import SegmentGEMMWrapper

 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper


 def get_stacked_name(name):
...
@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8


 class DeepseekV2MLP(nn.Module):
...
@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple

 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8


 class MiniCPM3MLP(nn.Module):
...
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
     assert_pkg_version,
     configure_logger,
     enable_show_time_cost,
+    is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
     prepare_model,
@@ -434,6 +435,10 @@ def _set_envs_and_config(server_args: ServerArgs):
             "at https://docs.flashinfer.ai/installation.html.",
         )

+    if is_hip():
+        # to figure out a better method of not using fork later
+        mp.set_start_method("spawn", force=True)
+

def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
...
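The spawn override above is flagged in the diff comment as a stopgap. The motivation: fork()ing after the HIP runtime has been initialized can leave child processes with an unusable GPU context, whereas "spawn" starts each worker in a fresh interpreter. A standalone illustration of the start method (generic Python multiprocessing, not sglang code):

import multiprocessing as mp


def worker(rank: int) -> None:
    print(f"worker {rank} started in a fresh interpreter")


if __name__ == "__main__":
    # "spawn" launches a brand-new Python process instead of fork()ing the parent,
    # so no GPU runtime state is inherited by the children.
    ctx = mp.get_context("spawn")
    procs = [ctx.Process(target=worker, args=(i,)) for i in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()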
@@ -21,6 +21,8 @@ import logging
 import random
 from typing import List, Optional, Union

+from sglang.srt.utils import is_hip
+
 logger = logging.getLogger(__name__)
@@ -164,6 +166,11 @@ class ServerArgs:
             )
             self.sampling_backend = "pytorch"

+        # ROCm: flashinfer available later
+        if is_hip():
+            self.attention_backend = "triton"
+            self.sampling_backend = "pytorch"
+
         # Default kernel backends
         if self.enable_mla:
             logger.info("MLA optimization is tunred on. Use triton backend.")
...
@@ -51,6 +51,11 @@ show_time_cost = False
 time_infos = {}

+# torch flag AMD GPU
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
 def enable_show_time_cost():
     global show_time_cost
     show_time_cost = True
...
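torch.version.hip is None on CUDA and CPU builds and a version string on ROCm builds, so this one-liner doubles as a cheap build-flavor probe that needs no device queries. A quick usage sketch (illustrative only):

import torch


def is_hip() -> bool:
    return torch.version.hip is not None


if is_hip():
    print(f"ROCm build detected, HIP version {torch.version.hip}")
else:
    print(f"Not a ROCm build, CUDA version: {torch.version.cuda}")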