Commit cffe15ef authored by zhuwenwen's avatar zhuwenwen
Browse files

Update deps, fix import and Optional type-hint errors, remove SUPPORT_TC (tc) gating

parent c004bf6e
...@@ -25,9 +25,9 @@ quart ...@@ -25,9 +25,9 @@ quart
fastrlock==0.8.3 fastrlock==0.8.3
cupy==12.3.0 cupy==12.3.0
torch == 2.5.1 torch == 2.7.1
triton == 3.1 triton == 3.1
flash_attn == 2.6.1 flash_attn == 2.6.1
flash_mla == 1.0.0 flash_mla == 1.0.0
lightop == 0.6.0 lightop == 0.6.0
lmslim == 0.3.1 # lmslim == 0.3.1
...@@ -10,7 +10,7 @@ from vllm.logger import init_logger ...@@ -10,7 +10,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType from vllm.scalar_type import ScalarType
from vllm.utils import direct_register_custom_op from vllm.utils.torch_utils import direct_register_custom_op
try: try:
from lmslim import quant_ops from lmslim import quant_ops
...@@ -472,10 +472,6 @@ def GetAWQShareWorkspaceSize()->int: ...@@ -472,10 +472,6 @@ def GetAWQShareWorkspaceSize()->int:
def GetAWQShareWorkspace()->torch.Tensor: def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace() return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
def awq_dequantize( def awq_dequantize(
qweight: torch.Tensor, qweight: torch.Tensor,
scales: torch.Tensor, scales: torch.Tensor,
...@@ -928,7 +924,7 @@ def rocblas_scaled_mm(a: torch.Tensor, ...@@ -928,7 +924,7 @@ def rocblas_scaled_mm(a: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: torch.Tensor | None = None) -> torch.Tensor:
# cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0 # cutlass_compatible_b = b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
# if current_platform.is_rocm() or not cutlass_compatible_b: # if current_platform.is_rocm() or not cutlass_compatible_b:
...@@ -947,7 +943,7 @@ def blaslt_scaled_mm(a: torch.Tensor, ...@@ -947,7 +943,7 @@ def blaslt_scaled_mm(a: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor: bias: torch.Tensor | None = None) -> torch.Tensor:
m = a.shape[0] m = a.shape[0]
n = b.shape[0] n = b.shape[0]
k = a.shape[1] k = a.shape[1]
...@@ -961,8 +957,8 @@ def triton_scaled_mm(a: torch.Tensor, ...@@ -961,8 +957,8 @@ def triton_scaled_mm(a: torch.Tensor,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None, bias: torch.Tensor | None = None,
best_config:Optional[list] = None) -> torch.Tensor: best_config: list | None = None) -> torch.Tensor:
return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config) return quant_ops.triton_scaled_mm(a, b,scale_a,scale_b,out_dtype,bias,best_config)
...@@ -974,8 +970,8 @@ def triton_int8_gemm_helper(m: int, ...@@ -974,8 +970,8 @@ def triton_int8_gemm_helper(m: int,
use_bias: bool, use_bias: bool,
out_dtype: type[torch.dtype] = torch.float16, out_dtype: type[torch.dtype] = torch.float16,
device: str = "cuda:0", device: str = "cuda:0",
best_config:Optional[list] = None, best_config: list | None = None,
repeat:Optional[int] = 2): repeat: int | None = 2):
return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat) return quant_tools.triton_int8_gemm_helper(m,n,k,per_token_act_quant,per_out_channel_weight_quant,use_bias,out_dtype,device,best_config,repeat)
def triton_blockint8_gemm_helper(m: int, def triton_blockint8_gemm_helper(m: int,
...@@ -985,8 +981,8 @@ def triton_blockint8_gemm_helper(m: int, ...@@ -985,8 +981,8 @@ def triton_blockint8_gemm_helper(m: int,
use_bias: bool=False, use_bias: bool=False,
out_dtype: type[torch.dtype] = torch.bfloat16, out_dtype: type[torch.dtype] = torch.bfloat16,
device: str = "cuda:0", device: str = "cuda:0",
best_config:Optional[dict] = None, best_config: dict | None = None,
repeat:Optional[int] = 2): repeat: int | None = 2):
return quant_tools.triton_blockint8_gemm_helper(m,n,k,block_size,use_bias,out_dtype,device,best_config,repeat) return quant_tools.triton_blockint8_gemm_helper(m,n,k,block_size,use_bias,out_dtype,device,best_config,repeat)
......
...@@ -7,7 +7,6 @@ import torch ...@@ -7,7 +7,6 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs from vllm import envs
from vllm.utils import SUPPORT_TC
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
if current_platform.is_cuda_alike(): if current_platform.is_cuda_alike():
...@@ -18,7 +17,7 @@ elif current_platform.is_xpu(): ...@@ -18,7 +17,7 @@ elif current_platform.is_xpu():
if HAS_TRITON: if HAS_TRITON:
from vllm.attention.ops.prefix_prefill import context_attention_fwd from vllm.attention.ops.prefix_prefill import context_attention_fwd
use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN and SUPPORT_TC use_tc = envs.VLLM_USE_OPT_OP and envs.VLLM_USE_TC_PAGED_ATTN
class PagedAttention: class PagedAttention:
......
...@@ -11,6 +11,7 @@ from vllm.config.utils import config ...@@ -11,6 +11,7 @@ from vllm.config.utils import config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import get_cpu_memory from vllm.utils.mem_utils import get_cpu_memory
from vllm import envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config.parallel import ParallelConfig from vllm.config.parallel import ParallelConfig
......
...@@ -49,7 +49,6 @@ from vllm.transformers_utils.utils import ( ...@@ -49,7 +49,6 @@ from vllm.transformers_utils.utils import (
) )
from vllm.utils.import_utils import LazyLoader from vllm.utils.import_utils import LazyLoader
from vllm.utils.torch_utils import common_broadcastable_dtype from vllm.utils.torch_utils import common_broadcastable_dtype
from vllm.utils import SUPPORT_TC
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PretrainedConfig from transformers import PretrainedConfig
...@@ -1674,7 +1673,7 @@ class ModelConfig: ...@@ -1674,7 +1673,7 @@ class ModelConfig:
@property @property
def use_mla(self) -> bool: def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE and SUPPORT_TC return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
@property @property
def is_matryoshka(self) -> bool: def is_matryoshka(self) -> bool:
......
...@@ -29,7 +29,6 @@ from vllm.model_executor.parameter import BasevLLMParameter ...@@ -29,7 +29,6 @@ from vllm.model_executor.parameter import BasevLLMParameter
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import SUPPORT_TC
DEFAULT_VOCAB_PADDING_SIZE = 64 DEFAULT_VOCAB_PADDING_SIZE = 64
...@@ -185,42 +184,15 @@ class VocabParallelEmbeddingShardIndices: ...@@ -185,42 +184,15 @@ class VocabParallelEmbeddingShardIndices:
assert self.num_added_elements <= self.num_added_elements_padded assert self.num_added_elements <= self.num_added_elements_padded
if SUPPORT_TC: @torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend) def get_masked_input_and_mask(
def get_masked_input_and_mask(
input_: torch.Tensor, input_: torch.Tensor,
org_vocab_start_index: int, org_vocab_start_index: int,
org_vocab_end_index: int, org_vocab_end_index: int,
num_org_vocab_padding: int, num_org_vocab_padding: int,
added_vocab_start_index: int, added_vocab_start_index: int,
added_vocab_end_index: int, added_vocab_end_index: int,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
added_vocab_mask = (input_ >= added_vocab_start_index) & (
input_ < added_vocab_end_index
)
added_offset = (
added_vocab_start_index
- (org_vocab_end_index - org_vocab_start_index)
- num_org_vocab_padding
)
valid_offset = (org_vocab_start_index * org_vocab_mask) + (
added_offset * added_vocab_mask
)
vocab_mask = org_vocab_mask | added_vocab_mask
input_ = vocab_mask * (input_ - valid_offset)
return input_, ~vocab_mask
else:
def get_masked_input_and_mask(
input_: torch.Tensor,
org_vocab_start_index: int,
org_vocab_end_index: int,
num_org_vocab_padding: int,
added_vocab_start_index: int,
added_vocab_end_index: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# torch.compile will fuse all of the pointwise ops below # torch.compile will fuse all of the pointwise ops below
# into a single kernel, making it very fast # into a single kernel, making it very fast
org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index) org_vocab_mask = (input_ >= org_vocab_start_index) & (input_ < org_vocab_end_index)
......
...@@ -98,10 +98,10 @@ class BasevLLMParameter(Parameter): ...@@ -98,10 +98,10 @@ class BasevLLMParameter(Parameter):
) )
self.data.copy_(loaded_weight) self.data.copy_(loaded_weight)
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False): def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: bool | None = False):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: bool | None = False):
self._assert_and_load(loaded_weight) self._assert_and_load(loaded_weight)
def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs): def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
...@@ -147,7 +147,7 @@ class _ColumnvLLMParameter(BasevLLMParameter): ...@@ -147,7 +147,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
def output_dim(self): def output_dim(self):
return self._output_dim return self._output_dim
def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False): def load_column_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: bool | None = False):
if not envs.VLLM_USE_NN or len( self.data.shape)==1 or is_quantization: if not envs.VLLM_USE_NN or len( self.data.shape)==1 or is_quantization:
shard_size = self.data.shape[self.output_dim] shard_size = self.data.shape[self.output_dim]
else: else:
...@@ -240,7 +240,7 @@ class RowvLLMParameter(BasevLLMParameter): ...@@ -240,7 +240,7 @@ class RowvLLMParameter(BasevLLMParameter):
def input_dim(self): def input_dim(self):
return self._input_dim return self._input_dim
def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: Optional[bool] = False): def load_row_parallel_weight(self, loaded_weight: torch.Tensor, is_quantization: bool | None = False):
if not envs.VLLM_USE_NN or is_quantization: if not envs.VLLM_USE_NN or is_quantization:
shard_size = self.data.shape[self.input_dim] shard_size = self.data.shape[self.input_dim]
else: else:
......
...@@ -14,12 +14,6 @@ from vllm.utils.torch_utils import cuda_device_count_stateless ...@@ -14,12 +14,6 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.utils import SUPPORT_TC
if not SUPPORT_TC:
os.environ['VLLM_USE_V1'] = '0'
os.environ['VLLM_USE_FLASH_ATTN_PA'] = '0'
os.environ['VLLM_USE_FLASH_MLA'] = '0'
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import VllmConfig from vllm.config import VllmConfig
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment