Unverified commit 42f39099 authored by kk, committed by GitHub

Unify sglang coding style (#2856)


Co-authored-by: Lin, Soga <soga.lin@amd.com>
parent 72c77763
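
The change is mechanical: direct `os.getenv` parsing is replaced by `get_bool_env_var` from `sglang.srt.utils`, and repeated `is_hip()` calls are replaced by a module-level `is_hip_` computed once at import time. A minimal sketch of the before/after pattern, assuming only what the diff below shows about these helpers:

import os

from sglang.srt.utils import get_bool_env_var, is_hip

# Old style: re-parse the env var and re-query the platform at every call site.
use_ck_moe = is_hip() and bool(int(os.getenv("CK_MOE", "0")))

# New style: cache the platform check once per module and route env-var parsing
# through the shared helper (its semantics are assumed from how it is used here).
is_hip_ = is_hip()
use_ck_moe = is_hip_ and get_bool_env_var("CK_MOE")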
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
-import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -19,7 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, permute_weight, set_weight_attrs
 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -28,6 +27,8 @@ else:
 import logging
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
@@ -99,7 +100,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
                 permute_weight(layer.w13_weight.data),
                 requires_grad=False,
@@ -163,7 +164,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             import ater
             from ater.fused_moe import fused_experts_ck
...
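
Note: the body of `get_bool_env_var` is not part of this diff; the sketch below is only an assumed implementation, inferred from the call sites it replaces, such as `bool(int(os.getenv("CK_MOE", "0")))`:

import os

# Assumed sketch only, not the verbatim sglang.srt.utils implementation;
# behavior is inferred from the call sites it replaces in the diff above.
def get_bool_env_var(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")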
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 import logging
-import os
 from typing import Any, Callable, Dict, List, Optional
 import torch
@@ -47,6 +46,8 @@ from sglang.srt.utils import (
 ACTIVATION_SCHEMES = ["static", "dynamic"]
+is_hip_ = is_hip()
+
 logger = logging.getLogger(__name__)
@@ -162,7 +163,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # kernel for fast weight-only FP8 quantization
         self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
         # Disable marlin for ROCm
-        if is_hip():
+        if is_hip_:
             self.use_marlin = False
         self.block_quant = self.quant_config.weight_block_size is not None
@@ -274,7 +275,7 @@ class Fp8LinearMethod(LinearMethodBase):
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.weight,
@@ -331,7 +332,7 @@ class Fp8LinearMethod(LinearMethodBase):
             weight_scale = layer.weight_scale
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight,
                     weight_scale=weight_scale,
@@ -568,7 +569,7 @@ class Fp8MoEMethod:
         # Block quant doesn't need to process weights after loading
         if self.block_quant:
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # activation_scheme: dynamic
                 w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=layer.w13_weight,
@@ -595,7 +596,7 @@ class Fp8MoEMethod:
         # If checkpoint is fp16 or bfloat16, quantize in place.
         if not self.quant_config.is_checkpoint_fp8_serialized:
             # If ROCm, use float8_e4m3fnuz instead (MI300x HW)
-            fp8_dtype = torch.float8_e4m3fnuz if is_hip() else torch.float8_e4m3fn
+            fp8_dtype = torch.float8_e4m3fnuz if is_hip_ else torch.float8_e4m3fn
             w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
             w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
@@ -617,8 +618,8 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,
@@ -629,7 +630,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -671,7 +672,7 @@ class Fp8MoEMethod:
             )
             # If ROCm, normalize the weights and scales to e4m3fnuz
-            if is_hip():
+            if is_hip_:
                 # Normalize the weights and scales
                 w13_weight, w13_weight_scale, w13_input_scale = (
                     normalize_e4m3fn_to_e4m3fnuz(
@@ -721,8 +722,8 @@ class Fp8MoEMethod:
                 max_w13_scales, requires_grad=False
             )
-            if is_hip():
-                if bool(int(os.getenv("CK_MOE", "0"))):
+            if is_hip_:
+                if get_bool_env_var("CK_MOE"):
                     layer.w13_weight = torch.nn.Parameter(
                         permute_weight(layer.w13_weight.data),
                         requires_grad=False,
@@ -733,7 +734,7 @@ class Fp8MoEMethod:
                         requires_grad=False,
                     )
                     torch.cuda.empty_cache()
-                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                elif get_bool_env_var("MOE_PADDING"):
                     # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -777,7 +778,7 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )
-        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+        if is_hip_ and get_bool_env_var("CK_MOE"):
             import ater
             from ater.fused_moe import fused_experts_ck
...
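
Taken together, `is_hip_` and the `CK_MOE` flag gate the ROCm-specific MoE path. A hedged sketch of the dispatch implied by the hunks above (`run_experts` is a hypothetical name used only for illustration; the other names appear in the diff):

# Illustrative only: dispatch implied by the diff, not verbatim sglang code.
if is_hip_ and get_bool_env_var("CK_MOE"):
    import ater  # AMD composable-kernel package, imported as in the diff
    from ater.fused_moe import fused_experts_ck

    run_experts = fused_experts_ck  # ROCm CK MoE kernels
else:
    run_experts = fused_experts  # default Triton fused MoE kernels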