Commit 9a4a94ee authored by zhuwenwen's avatar zhuwenwen
Browse files

add ggml_dequantize

parent 76572db3
...@@ -976,52 +976,52 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, ...@@ -976,52 +976,52 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
# return torch.empty((m, n), device=a.device, dtype=a.dtype) # return torch.empty((m, n), device=a.device, dtype=a.dtype)
# if hasattr(torch.ops._C, "ggml_dequantize"): if hasattr(torch.ops._C, "ggml_dequantize"):
# @register_fake("_C::ggml_dequantize") @register_fake("_C::ggml_dequantize")
# def _ggml_dequantize_fake( def _ggml_dequantize_fake(
# W: torch.Tensor, W: torch.Tensor,
# quant_type: int, quant_type: int,
# m: torch.SymInt, m: torch.SymInt,
# n: torch.SymInt, n: torch.SymInt,
# dtype: Optional[torch.dtype] = None) -> torch.Tensor: dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# return torch.empty((m, n), dtype=torch.float16, device=W.device) return torch.empty((m, n), dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_vec_a8") @register_fake("_C::ggml_mul_mat_vec_a8")
# def _ggml_mul_mat_vec_a8_fake( def _ggml_mul_mat_vec_a8_fake(
# W: torch.Tensor, W: torch.Tensor,
# X: torch.Tensor, X: torch.Tensor,
# quant_type: int, quant_type: int,
# row: torch.SymInt, row: torch.SymInt,
# ) -> torch.Tensor: ) -> torch.Tensor:
# return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device) return torch.empty((X.shape[0], row), dtype=X.dtype, device=W.device)
# @register_fake("_C::ggml_mul_mat_a8") @register_fake("_C::ggml_mul_mat_a8")
# def _ggml_mul_mat_a8_fake( def _ggml_mul_mat_a8_fake(
# W: torch.Tensor, W: torch.Tensor,
# X: torch.Tensor, X: torch.Tensor,
# quant_type: int, quant_type: int,
# row: torch.SymInt, row: torch.SymInt,
# ) -> torch.Tensor: ) -> torch.Tensor:
# batch = X.size(0) batch = X.size(0)
# return torch.empty((batch, row), dtype=X.dtype, device=W.device) return torch.empty((batch, row), dtype=X.dtype, device=W.device)
# @register_fake("_C::ggml_moe_a8") @register_fake("_C::ggml_moe_a8")
# def _ggml_moe_a8_fake( def _ggml_moe_a8_fake(
# X: torch.Tensor, X: torch.Tensor,
# W: torch.Tensor, W: torch.Tensor,
# sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
# expert_ids: torch.Tensor, expert_ids: torch.Tensor,
# num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
# quant_type: int, quant_type: int,
# row: torch.SymInt, row: torch.SymInt,
# top_k: torch.SymInt, top_k: torch.SymInt,
# tokens: torch.SymInt, tokens: torch.SymInt,
# ) -> torch.Tensor: ) -> torch.Tensor:
# tokens = X.size(0) tokens = X.size(0)
# return torch.empty((tokens * top_k, row), return torch.empty((tokens * top_k, row),
# dtype=torch.float16, dtype=torch.float16,
# device=W.device) device=W.device)
if hasattr(torch.ops._C, "ggml_moe_a8_vec"): if hasattr(torch.ops._C, "ggml_moe_a8_vec"):
...@@ -1852,9 +1852,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -1852,9 +1852,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf # gguf
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int, def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
# dtype: Optional[torch.dtype]) -> torch.Tensor: dtype: Optional[torch.dtype]) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype) return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
def ggml_mul_mat_vec_a8( def ggml_mul_mat_vec_a8(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment