Unverified Commit 177320a5 authored by Lianmin Zheng, committed by GitHub

Clean up imports (#5467)

parent d7bc19a4
@@ -48,7 +48,7 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import awq_dequantize
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import awq_dequantize
class DeepseekModelNextN(nn.Module):
@@ -273,7 +273,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
                         self_attn.kv_b_proj.qzeros,
                     ).T
                 else:
-                    w = ops.awq_dequantize(
+                    w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
                         self_attn.kv_b_proj.qzeros,
...
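For context, a minimal sketch of the pattern these hunks converge on: both branches bind the same name `awq_dequantize`, so call sites can drop the old `ops.` prefix regardless of which backend is installed. The sketch assumes either `sgl_kernel` (CUDA builds) or vLLM is present; it only restates the import logic shown above.

```python
# Pick the kernel provider once at import time, then use one unqualified name.
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import awq_dequantize
else:
    from vllm._custom_ops import awq_dequantize

# Call sites no longer need the `ops.` prefix:
#     w = awq_dequantize(qweight, scales, qzeros).T
```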
@@ -51,6 +51,7 @@ from sglang.srt.layers.linear import (
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -80,10 +81,8 @@ _is_cuda = is_cuda()
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import awq_dequantize

 if _is_hip:
     from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
@@ -861,7 +860,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
-                q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                q_nope.transpose(0, 1),
             )
             q_nope_out = bmm_fp8(
                 q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
@@ -892,7 +891,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
-                attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                attn_output.transpose(0, 1),
             )
             attn_bmm_output = bmm_fp8(
                 attn_output_val,
@@ -1565,7 +1564,7 @@ class DeepseekV2ForCausalLM(nn.Module):
                         self_attn.kv_b_proj.qzeros,
                     ).T
                 else:
-                    w = ops.awq_dequantize(
+                    w = awq_dequantize(
                         self_attn.kv_b_proj.qweight,
                         self_attn.kv_b_proj.scales,
                         self_attn.kv_b_proj.qzeros,
...
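The two `per_tensor_quant_mla_fp8` hunks above drop the explicit `dtype=torch.float8_e4m3fn` argument at the call sites. The helper's real signature is not shown in this diff; the toy `per_tensor_quant_demo` below is hypothetical and only illustrates the idea of relying on a default fp8 dtype instead of repeating it at every call.

```python
import torch

def per_tensor_quant_demo(x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
    """Toy per-tensor quantization: one scale for the whole tensor."""
    finfo = torch.finfo(dtype)
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_q, scale

x = torch.randn(16, 128)
x_q, s = per_tensor_quant_demo(x)                          # relies on the default dtype
x_q2, s2 = per_tensor_quant_demo(x, torch.float8_e4m3fn)   # equivalent, explicit
```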
-import re
 from typing import Dict, Tuple
...
@@ -10,12 +10,11 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor

-logger = logging.getLogger(__name__)

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch

+logger = logging.getLogger(__name__)


 @dataclasses.dataclass
 class SamplingBatchInfo:
...
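The hunk above only relocates the module-level logger so it sits after all imports, including the `TYPE_CHECKING`-guarded one. A minimal sketch of the resulting layout:

```python
import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Type-checking-only import; never executed at runtime, so it cannot
    # introduce an import cycle.
    from sglang.srt.managers.schedule_batch import ScheduleBatch

# Module-level objects come after the import block, including guarded imports.
logger = logging.getLogger(__name__)
```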
-# Copyright 2023-2024 SGLang Team
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-# Some shortcuts for backward compatibility.
-# They will be removed in new versions.
-from sglang.srt.entrypoints.engine import Engine
-from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
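With this backward-compatibility shim removed, downstream code imports the entrypoints from their canonical modules instead (the test runner hunk further down makes the same switch). A minimal sketch of the replacement imports, assuming an installed sglang:

```python
# Import the engine and HTTP-server entrypoints directly, rather than
# going through the removed shim module.
from sglang.srt.entrypoints.engine import Engine
from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
```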
@@ -187,6 +187,7 @@ class ServerArgs:
     n_share_experts_fusion: int = 0
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
+    disable_fast_image_processor: bool = False

     # Debug tensor dumps
     debug_tensor_dump_output_folder: Optional[str] = None
@@ -198,9 +199,6 @@ class ServerArgs:
     disaggregation_bootstrap_port: int = 8998
     disaggregation_transfer_backend: str = "mooncake"

-    # multimodal
-    disable_fast_image_processor: bool = False

     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -1136,6 +1134,11 @@ class ServerArgs:
             action="store_true",
             help="Disable chunked prefix cache feature for deepseek, which should save overhead for short sequences.",
         )
+        parser.add_argument(
+            "--disable-fast-image-processor",
+            action="store_true",
+            help="Adopt base image processor instead of fast image processor.",
+        )

         # Server warmups
         parser.add_argument(
@@ -1187,13 +1190,6 @@ class ServerArgs:
             help="The backend for disaggregation transfer. Default is mooncake.",
         )

-        # Multimodal
-        parser.add_argument(
-            "--disable-fast-image-processor",
-            action="store_true",
-            help="Adopt base image processor instead of fast image processor.",
-        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
...
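The server_args.py hunks above only move the `--disable-fast-image-processor` flag and its dataclass field next to the other `disable_*` options. A self-contained sketch of how a `store_true` flag maps onto a boolean dataclass field; `DemoArgs` is illustrative, not SGLang code:

```python
import argparse
import dataclasses

@dataclasses.dataclass
class DemoArgs:
    disable_fast_image_processor: bool = False  # mirrors the ServerArgs default

parser = argparse.ArgumentParser()
parser.add_argument(
    "--disable-fast-image-processor",
    action="store_true",  # defaults to False, flips to True when the flag is passed
    help="Adopt base image processor instead of fast image processor.",
)

args = parser.parse_args(["--disable-fast-image-processor"])
demo = DemoArgs(disable_fast_image_processor=args.disable_fast_image_processor)
assert demo.disable_fast_image_processor is True
```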
@@ -55,7 +55,6 @@ import torch.distributed
 import torch.distributed as dist
 import triton
 import zmq
-from decord import VideoReader, cpu
 from fastapi.responses import ORJSONResponse
 from packaging import version as pkg_version
 from PIL import Image
@@ -545,6 +544,9 @@ def load_audio(audio_file: str, sr: int = 16000, mono: bool = True) -> np.ndarray:
 def encode_video(video_path, frame_count_limit=None):
+    # Lazy import because decord is not available on some arm platforms.
+    from decord import VideoReader, cpu
+
     if not os.path.exists(video_path):
         logger.error(f"Video {video_path} does not exist")
         return []
...
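These utils.py hunks move the decord import from module scope into `encode_video`, so importing `sglang.srt.utils` no longer fails on platforms where decord cannot be installed. A minimal sketch of the lazy-import pattern; `encode_video_demo` is illustrative, not the actual function body:

```python
def encode_video_demo(video_path, frame_count_limit=None):
    # Deferred import: only callers that actually decode video need decord.
    from decord import VideoReader, cpu

    vr = VideoReader(video_path, ctx=cpu(0))
    num_frames = len(vr)
    if frame_count_limit is not None:
        num_frames = min(num_frames, frame_count_limit)
    return [vr[i].asnumpy() for i in range(num_frames)]
```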
@@ -26,8 +26,8 @@ from transformers import (
     AutoProcessor,
 )

+from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.server import Engine
 from sglang.srt.utils import load_image
 from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l
...
@@ -3,7 +3,7 @@
 import pytest
 import torch

-from sglang.srt.custom_op import scaled_fp8_quant
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import is_cuda
...
@@ -93,9 +93,7 @@ class TestPerTokenGroupQuantFP8(TestFP8Base):
         A, A_quant_gt, scale_gt = self._make_A(
             M=self.M, K=self.K, group_size=self.group_size, out_dtype=self.quant_type
         )
-        A_quant, scale = per_token_group_quant_fp8(
-            x=A, group_size=self.group_size, dtype=self.quant_type
-        )
+        A_quant, scale = per_token_group_quant_fp8(x=A, group_size=self.group_size)
         torch.testing.assert_close(scale, scale_gt)
         diff = (A_quant.to(torch.float16) - A_quant_gt.to(torch.float16)).abs()
         diff_count = (diff > 1e-5).count_nonzero()
...
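For context on what the updated call exercises: per-token-group quantization gives each contiguous group of `group_size` elements in a row its own scale. The toy `per_token_group_quant_ref` below is illustrative only, not the kernel under test, and it assumes fp8-e4m3 is the helper's default dtype now that the explicit `dtype` argument is gone:

```python
import torch

def per_token_group_quant_ref(x: torch.Tensor, group_size: int, dtype=torch.float8_e4m3fn):
    M, K = x.shape
    assert K % group_size == 0
    finfo = torch.finfo(dtype)
    groups = x.view(M, K // group_size, group_size)
    # One scale per (token, group), sized so the largest element maps to finfo.max.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    x_q = (groups / scale).clamp(finfo.min, finfo.max).to(dtype).view(M, K)
    return x_q, scale.squeeze(-1)
```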
@@ -3,9 +3,9 @@ import unittest
 import torch

-from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.test.test_utils import CustomTestCase
@@ -41,7 +41,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     B, D = a.shape
     # Perform per-token quantization
-    a_q, a_s = sgl_scaled_fp8_quant(a, use_per_token_if_dynamic=True)
+    a_q, a_s = scaled_fp8_quant(a, use_per_token_if_dynamic=True)
     # Repeat tokens to match topk
     a_q = a_q.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D)
     # Also repeat the scale
@@ -69,7 +69,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
     # Activation function
     act_out = SiluAndMul().forward_native(inter_out)
     # Quantize activation output with per-token
-    act_out_q, act_out_s = sgl_scaled_fp8_quant(
+    act_out_q, act_out_s = scaled_fp8_quant(
         act_out, use_per_token_if_dynamic=True
     )
...
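These last two hunks only rename the import; `scaled_fp8_quant(..., use_per_token_if_dynamic=True)` still quantizes dynamically with one scale per token (row) rather than one scale for the whole tensor. A toy sketch of that behavior, illustrative only and not the fp8_kernel implementation:

```python
import torch

def per_token_fp8_quant_ref(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Shape [num_tokens, 1]: each row gets its own dynamic scale.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / finfo.max
    x_q = (x / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_q, scale
```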