Unverified commit 2b0fc594 authored by Lianmin Zheng, committed by GitHub

[Minor] Code style improvements (#2355)

parent 9cc733b3
@@ -2,12 +2,10 @@
 Common utilities for torchao.
 """
 
-from typing import Dict, Set
-
 import torch
 
 
-def apply_torchao_config_to_model_(
+def apply_torchao_config_to_model(
     model: torch.nn.Module, torchao_config: str, filter_fn=None
 ):
     """Quantize a modelwith torchao quantization specified by torchao_config
@@ -21,6 +19,7 @@ def apply_torchao_config_to_model_(
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -28,6 +27,11 @@ def apply_torchao_config_to_model_(
     )
     from torchao.quantization.observer import PerRow, PerTensor
 
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
     if torchao_config == "" or torchao_config is None:
         return model
     elif "int8wo" in torchao_config:
@@ -44,8 +48,6 @@ def apply_torchao_config_to_model_(
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(model, float8_weight_only(), filter_fn=filter_fn)
...
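For context, here is a minimal sketch of what the new in-function default does when callers omit `filter_fn`: only submodules whose fully qualified name contains "proj" are selected for quantization. `TinyBlock` and `default_filter_fn` are illustrative names, not part of the repository.

```python
import torch


def default_filter_fn(module, fqn):
    # Same predicate the diff installs when filter_fn is None:
    # quantize only submodules whose fully qualified name contains "proj".
    return "proj" in fqn


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(16, 16)  # matched -> would be quantized
        self.norm = torch.nn.LayerNorm(16)     # not matched -> left untouched


model = TinyBlock()
for fqn, module in model.named_modules():
    label = "quantize" if default_filter_fn(module, fqn) else "skip"
    print(fqn or "<root>", "->", label)
```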
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
         if "FusedMoE" in sub.__class__.__name__:
             if batch_size == 1:
                 # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to skip it for now.
+                # so we decide to only use torch.compile when bs = 1
                 sub._forward_method = fused_moe_forward_native
             else:
                 sub._forward_method = sub.forward_native
...
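As the updated comment notes, the compile-friendly native path is installed only for batch size 1. A rough sketch of that dispatch pattern, assuming a dummy `DummyFusedMoE` module and stand-in forward functions (none of these names are from the repository):

```python
import torch


def fused_moe_forward_native(x):
    return x * 2  # stand-in for the compile-friendly native MoE forward


class DummyFusedMoE(torch.nn.Module):
    def forward_native(self, x):
        return x + 1  # stand-in for the layer's default eager forward

    def forward(self, x):
        return self._forward_method(x)


def to_torch(model: torch.nn.Module, batch_size: int):
    # Mirror of the dispatch above: native fused-moe path only when bs == 1.
    for sub in model.modules():
        if "FusedMoE" in sub.__class__.__name__:
            sub._forward_method = (
                fused_moe_forward_native if batch_size == 1 else sub.forward_native
            )


moe = DummyFusedMoE()
to_torch(moe, batch_size=1)
print(moe(torch.ones(2)))  # tensor([2., 2.]) via the native fused-moe path
```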
@@ -27,7 +27,6 @@ from vllm.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,7 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
-from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model_
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -112,11 +111,13 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
             server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static} "
+                f"and turn off chunked prefill "
+                f"because this is a multimodal model."
+            )
 
         # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
@@ -160,11 +161,8 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        def filter_fn(module, fqn):
-            return "proj" in fqn
-
-        apply_torchao_config_to_model_(
-            self.model, global_server_args_dict["torchao_config"], filter_fn
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
         )
 
         # Init memory pool and attention backends
...
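One effect of moving the `logger.info` call after the adjustment is that the message now reports the reduced value. A tiny illustration with hypothetical numbers:

```python
# Hypothetical starting values; the real ones come from server_args / ModelRunner.
mem_fraction_static = 0.88
chunked_prefill_size = 8192

chunked_prefill_size = -1    # -1 turns off chunked prefill
mem_fraction_static *= 0.95  # 0.88 -> ~0.836, the value the log now reports

print(
    f"Automatically reduce --mem-fraction-static to {mem_fraction_static} "
    f"and turn off chunked prefill because this is a multimodal model."
)
```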