Unverified commit 4561f139, authored by Michael Goin and committed by GitHub
Browse files

[Refactor] Rename `gptq_marlin` to `marlin` to match MoE (#32952)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 6cc6d92b
...@@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"): ...@@ -591,8 +591,8 @@ if hasattr(torch.ops._C, "gptq_marlin_24_gemm"):
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype)
@register_fake("_C::gptq_marlin_gemm") @register_fake("_C::marlin_gemm")
def _gptq_marlin_gemm_fake( def _marlin_gemm_fake(
a: torch.Tensor, a: torch.Tensor,
c: torch.Tensor | None, c: torch.Tensor | None,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
...@@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess( ...@@ -1312,7 +1312,7 @@ def marlin_int4_fp8_preprocess(
return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace) return torch.ops._C.marlin_int4_fp8_preprocess(qweight, qzeros_or_none, inplace)
def gptq_marlin_gemm( def marlin_gemm(
a: torch.Tensor, a: torch.Tensor,
c: torch.Tensor | None, c: torch.Tensor | None,
b_q_weight: torch.Tensor, b_q_weight: torch.Tensor,
...@@ -1333,7 +1333,7 @@ def gptq_marlin_gemm( ...@@ -1333,7 +1333,7 @@ def gptq_marlin_gemm(
use_fp32_reduce: bool = False, use_fp32_reduce: bool = False,
is_zp_float: bool = False, is_zp_float: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops._C.gptq_marlin_gemm( return torch.ops._C.marlin_gemm(
a, a,
c, c,
b_q_weight, b_q_weight,
......
...@@ -563,7 +563,7 @@ def apply_gptq_marlin_linear( ...@@ -563,7 +563,7 @@ def apply_gptq_marlin_linear(
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
...@@ -628,7 +628,7 @@ def apply_awq_marlin_linear( ...@@ -628,7 +628,7 @@ def apply_awq_marlin_linear(
) )
reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype) reshaped_x, a_scales = marlin_quant_input(reshaped_x, input_dtype)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
reshaped_x, reshaped_x,
None, None,
weight, weight,
......
...@@ -121,7 +121,7 @@ def apply_fp4_marlin_linear( ...@@ -121,7 +121,7 @@ def apply_fp4_marlin_linear(
inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
a=inputs, a=inputs,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
......
...@@ -66,7 +66,7 @@ def apply_fp8_marlin_linear( ...@@ -66,7 +66,7 @@ def apply_fp8_marlin_linear(
# inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn) # inputs, a_scales = marlin_quant_input(inputs, torch.float8_e4m3fn)
raise RuntimeError("Marlin W8A8 is not supported.") raise RuntimeError("Marlin W8A8 is not supported.")
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
a=inputs, a=inputs,
c=None, c=None,
b_q_weight=weight, b_q_weight=weight,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment