"tests/vscode:/vscode.git/clone" did not exist on "628d00cd7b06c9706b0613aafaefe927fb255877"
Commit 28a8a733 authored by zhuwenwen's avatar zhuwenwen
Browse files

skip ggml_dequantize

parent f9b567df
......@@ -2,7 +2,7 @@ import contextlib
import functools
import importlib
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Type
import torch
import torch.library
......@@ -863,32 +863,32 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
memory_format=torch.contiguous_format)
if hasattr(torch.ops._C, "ggml_dequantize"):
# if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
                          m: torch.SymInt,
                          n: torch.SymInt) -> torch.Tensor:
    """Fake counterpart registered for the ``_C::ggml_dequantize`` op.

    Returns an uninitialized float16 tensor of shape ``(m, n)`` placed on
    the same device as ``W``. ``quant_type`` is accepted to mirror the op's
    signature but is not read here.
    """
    out_shape = (m, n)
    return torch.empty(out_shape, dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_dequantize")
# def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt) -> torch.Tensor:
# return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: torch.SymInt,
) -> torch.Tensor:
    """Fake counterpart registered for the ``_C::ggml_mul_mat_vec_a8`` op.

    Returns an uninitialized float16 tensor of shape ``(1, row)`` on the
    device of ``W``. ``X`` and ``quant_type`` are accepted to mirror the
    op's signature but are not read here.
    """
    out_shape = (1, row)
    target_device = W.device
    return torch.empty(out_shape, dtype=torch.float16, device=target_device)
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
    W: torch.Tensor,
    X: torch.Tensor,
    quant_type: int,
    row: torch.SymInt,
) -> torch.Tensor:
    """Fake counterpart registered for the ``_C::ggml_mul_mat_a8`` op.

    Returns an uninitialized float16 tensor of shape ``(X.size(0), row)``
    on the device of ``W``. ``quant_type`` is accepted to mirror the op's
    signature but is not read here.
    """
    # The leading output dimension follows X's first dimension.
    out_shape = (X.shape[0], row)
    return torch.empty(out_shape, dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_vec_a8")
# def _ggml_mul_mat_vec_a8_fake(
# W: torch.Tensor,
# X: torch.Tensor,
# quant_type: int,
# row: torch.SymInt,
# ) -> torch.Tensor:
# return torch.empty((1, row), dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_a8")
# def _ggml_mul_mat_a8_fake(
# W: torch.Tensor,
# X: torch.Tensor,
# quant_type: int,
# row: torch.SymInt,
# ) -> torch.Tensor:
# batch = X.size(0)
# return torch.empty((batch, row), dtype=torch.float16, device=W.device)
# cutlass
......@@ -1231,9 +1231,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
                    n: int) -> torch.Tensor:
    """Forward all arguments to ``torch.ops._C.ggml_dequantize``.

    The arguments are passed through unchanged and the op's result is
    returned as-is.
    """
    dequantize_op = torch.ops._C.ggml_dequantize
    return dequantize_op(W, quant_type, m, n)
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
# n: int) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
def ggml_mul_mat_vec_a8(
......
......@@ -15,6 +15,7 @@ from vllm.platforms import current_platform
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
from vllm.utils import (enable_trace_function_call_for_thread,
resolve_obj_by_qualname, update_environment_variables)
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment