"vscode:/vscode.git/clone" did not exist on "c09dade2a263b6f684d2fbf390c9c1c64761e953"
Commit 28a8a733 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip ggml_dequantize

parent f9b567df
...@@ -2,7 +2,7 @@ import contextlib ...@@ -2,7 +2,7 @@ import contextlib
import functools import functools
import importlib import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Type
import torch import torch
import torch.library import torch.library
...@@ -863,32 +863,32 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -863,32 +863,32 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format=torch.contiguous_format) memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "ggml_dequantize"): # if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize") # @register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, # def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
m: torch.SymInt, # m: torch.SymInt,
n: torch.SymInt) -> torch.Tensor: # n: torch.SymInt) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device) # return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8") # @register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake( # def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor, # W: torch.Tensor,
X: torch.Tensor, # X: torch.Tensor,
quant_type: int, # quant_type: int,
row: torch.SymInt, # row: torch.SymInt,
) -> torch.Tensor: # ) -> torch.Tensor:
return torch.empty((1, row), dtype=torch.float16, device=W.device) # return torch.empty((1, row), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_a8") # @register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake( # def _ggml_mul_mat_a8_fake(
W: torch.Tensor, # W: torch.Tensor,
X: torch.Tensor, # X: torch.Tensor,
quant_type: int, # quant_type: int,
row: torch.SymInt, # row: torch.SymInt,
) -> torch.Tensor: # ) -> torch.Tensor:
batch = X.size(0) # batch = X.size(0)
return torch.empty((batch, row), dtype=torch.float16, device=W.device) # return torch.empty((batch, row), dtype=torch.float16, device=W.device)
# cutlass # cutlass
...@@ -1231,9 +1231,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -1231,9 +1231,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf # gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, # def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
n: int) -> torch.Tensor: # n: int) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n) # return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
def ggml_mul_mat_vec_a8( def ggml_mul_mat_vec_a8(
......
...@@ -15,6 +15,7 @@ from vllm.platforms import current_platform ...@@ -15,6 +15,7 @@ from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread, from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables) resolve_obj_by_qualname, update_environment_variables)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner_base import (BroadcastableModelInput, from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase, ModelRunnerBase,
ModelRunnerInputBase) ModelRunnerInputBase)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment