Unverified commit 5dc54f1a authored by Yineng Zhang, committed by GitHub

feat: remove vllm distributed (#2907)


Co-authored-by: Zhangyi <1109276519@qq.com>
parent f3e9b489
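The change is mechanical across every file touched below: helpers previously imported from vllm.distributed are now imported from sglang.srt.distributed, with call sites left unchanged. A minimal before/after sketch of the pattern (symbol names taken from the hunks below; the exact set varies per file):

# Before (removed): tensor-parallel helpers came from vllm.
# from vllm.distributed import (
#     get_tensor_model_parallel_rank,
#     get_tensor_model_parallel_world_size,
#     tensor_model_parallel_all_reduce,
# )
# After (added): the same symbols are served by sglang's own package.
from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)

def tp_shard_info():
    # Valid only after the distributed environment has been initialized
    # (see the ModelRunner hunk below); shown purely to illustrate that only
    # the import path changes, not the call sites.
    return get_tensor_model_parallel_rank(), get_tensor_model_parallel_world_size()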
@@ -25,13 +25,13 @@ from sglang.srt.utils import is_flashinfer_available
 if is_flashinfer_available():
     from flashinfer.activation import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.custom_op import CustomOp
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs
...
 import torch
-from vllm.distributed import GroupCoordinator, get_tp_group
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
...
@@ -7,7 +7,8 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -15,7 +16,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     PackedColumnParameter,
...
@@ -20,11 +20,11 @@ import torch
 import triton
 import triton.language as tl
 from torch import nn
-from vllm.distributed import (
+from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
...
@@ -4,13 +4,13 @@ from typing import Callable, List, Optional, Tuple
 import torch
 from torch.nn import Module
 from vllm import _custom_ops as ops
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.ep_moe.kernels import (
     grouped_gemm_triton,
...
@@ -5,13 +5,13 @@ from enum import Enum
 from typing import Callable, List, Optional, Tuple
 import torch
-from vllm.distributed import (
+from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.custom_op import CustomOp
 from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.moe.fused_moe_native import moe_forward_native
 from sglang.srt.layers.moe.topk import select_experts
...
@@ -6,7 +6,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
-from vllm.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 __all__ = [
     "BasevLLMParameter",
...
@@ -8,7 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
     apply_fp8_marlin_linear,
@@ -24,6 +23,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     requantize_with_max_scale,
 )
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
...
@@ -6,13 +6,13 @@ from typing import List, Optional, Sequence, Tuple
 import torch
 import torch.nn.functional as F
 from torch.nn.parameter import Parameter, UninitializedParameter
-from vllm.distributed import (
+from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
...
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable
 import torch
 import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
-from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
+from sglang.srt.distributed import get_tensor_model_parallel_rank
+from sglang.srt.distributed.parallel_state import graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
...
@@ -21,16 +21,17 @@ from typing import List, Optional, Tuple
 import torch
 import torch.distributed as dist
-from vllm.distributed import (
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
@@ -295,12 +296,15 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
         # Load the model
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
                 device_config=DeviceConfig(self.device),
             )
+        monkey_patch_vllm_parallel_state(reverse=True)
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
...
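The ModelRunner hunk above is the one behavioral addition in this commit: monkey_patch_vllm_parallel_state() is applied just before get_model() and reverted right after, presumably so vllm-backed quantization layers constructed during loading resolve their parallel state through sglang's groups. A hypothetical helper (not part of this commit) that scopes the patch the same way:

from contextlib import contextmanager

from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state

@contextmanager
def vllm_parallel_state_patched():
    # Hypothetical convenience wrapper: apply the patch for the duration of
    # model loading, then restore vllm's original parallel-state lookups.
    monkey_patch_vllm_parallel_state()
    try:
        yield
    finally:
        monkey_patch_vllm_parallel_state(reverse=True)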
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-        from vllm.distributed import get_tensor_model_parallel_rank
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-        from vllm.distributed import get_tensor_model_parallel_rank
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
...
@@ -19,10 +19,10 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once
...
@@ -24,10 +24,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -35,6 +31,10 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
...
@@ -21,10 +21,10 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
...
@@ -44,12 +44,12 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
...
@@ -19,14 +19,14 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
...
@@ -21,13 +21,13 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
...
@@ -23,14 +23,14 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
...
@@ -20,9 +20,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
...