Commit 9a4a94ee authored by zhuwenwen's avatar zhuwenwen
Browse files

add ggml_dequantize

parent 76572db3
......@@ -976,52 +976,52 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# return torch.empty((m, n), device=a.device, dtype=a.dtype)
# if hasattr(torch.ops._C, "ggml_dequantize"):
# @register_fake("_C::ggml_dequantize")
# def _ggml_dequantize_fake(
# W: torch.Tensor,
# quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt,
# dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# return torch.empty((m, n), dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_vec_a8")
# def _ggml_mul_mat_vec_a8_fake(
# W: torch.Tensor,
# X: torch.Tensor,
# quant_type: int,
# row: torch.SymInt,
# ) -> torch.Tensor:
# return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)
# @register_fake("_C::ggml_mul_mat_a8")
# def _ggml_mul_mat_a8_fake(
# W: torch.Tensor,
# X: torch.Tensor,
# quant_type: int,
# row: torch.SymInt,
# ) -> torch.Tensor:
# batch = X.size(0)
# return torch.empty((batch, row), dtype=X.dtype, device=W.device)
# @register_fake("_C::ggml_moe_a8")
# def _ggml_moe_a8_fake(
# X: torch.Tensor,
# W: torch.Tensor,
# sorted_token_ids: torch.Tensor,
# expert_ids: torch.Tensor,
# num_tokens_post_padded: torch.Tensor,
# quant_type: int,
# row: torch.SymInt,
# top_k: torch.SymInt,
# tokens: torch.SymInt,
# ) -> torch.Tensor:
# tokens = X.size(0)
# return torch.empty((tokens * top_k, row),
# dtype=torch.float16,
# device=W.device)
if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(
W: torch.Tensor,
quant_type: int,
m: torch.SymInt,
n: torch.SymInt,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8")
def _ggml_mul_mat_vec_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)
@register_fake("_C::ggml_mul_mat_a8")
def _ggml_mul_mat_a8_fake(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: torch.SymInt,
) -> torch.Tensor:
batch = X.size(0)
return torch.empty((batch, row), dtype=X.dtype, device=W.device)
@register_fake("_C::ggml_moe_a8")
def _ggml_moe_a8_fake(
X: torch.Tensor,
W: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
quant_type: int,
row: torch.SymInt,
top_k: torch.SymInt,
tokens: torch.SymInt,
) -> torch.Tensor:
tokens = X.size(0)
return torch.empty((tokens * top_k, row),
dtype=torch.float16,
device=W.device)
if hasattr(torch.ops._C, "ggml_moe_a8_vec"):
......@@ -1852,9 +1852,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
# dtype: Optional[torch.dtype]) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
dtype: Optional[torch.dtype]) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
def ggml_mul_mat_vec_a8(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment