"test/verify/test_unbatched_gemm_1.cpp" did not exist on "985f58b009280b531e80fd7f95b5135ef3d8ecd1"
Commit 29a90944 authored by Xtra, committed by GitHub

add mxfp4 kernels and rename some func for clarity (#148)

parent 505c5a47
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
FLOAT4_E2M1_MAX = 6.0
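
FLOAT4_E2M1_MAX is the largest magnitude FP4 E2M1 can represent. Combined with the FP8 E4M3 maximum of 448.0 used for the per-block scales, it gives the 2688.0 constant that appears in the global-scale computations below. A minimal sketch of that relationship (the helper name is hypothetical, not from this commit):

import torch

FLOAT4_E2M1_MAX = 6.0    # largest magnitude representable in FP4 E2M1
FLOAT8_E4M3_MAX = 448.0  # largest magnitude of the FP8 E4M3 block scales

def nvfp4_global_scale(t: torch.Tensor) -> torch.Tensor:
    # 448.0 * 6.0 = 2688.0, the constant used in load_fp4_weight further down.
    return (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / torch.abs(t).max()).to(torch.float32)
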
@@ -110,8 +110,8 @@ def test_nvfp4_gemm(
print(f"b_global_scale : {b_global_scale}, {b_global_scale.shape}")
alpha = 1.0 / (a_global_scale * b_global_scale)
a_fp4, a_scale_interleaved = scaled_fp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = scaled_fp4_quant(b_dtype, b_global_scale)
a_fp4, a_scale_interleaved = scaled_nvfp4_quant(a_dtype, a_global_scale)
b_fp4, b_scale_interleaved = scaled_nvfp4_quant(b_dtype, b_global_scale)
expected_out = get_ref_results(
a_fp4,
@@ -130,7 +130,7 @@ def test_nvfp4_gemm(
print(f"alpha {alpha}, {alpha.shape}, {alpha.dtype}")
out = cutlass_scaled_fp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
out = cutlass_scaled_nvfp4_mm(a_fp4, b_fp4, a_scale_interleaved, b_scale_interleaved, alpha, bias)
print(f"out : {out}, {out.shape}, {out.dtype}")
print(f"expected_out : {expected_out}, {expected_out.shape}, {expected_out.dtype}")
......
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm
import time
@@ -14,13 +14,13 @@ class MMWeightFp4:
    @torch.no_grad()
    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
-        output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
+        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    @torch.no_grad()
    def load_fp4_weight(self, weight, bias):
        self.weight_global_scale = (2688.0 / torch.max(torch.abs(weight))).to(torch.float32)
-        self.weight, self.weight_scale = scaled_fp4_quant(weight, self.weight_global_scale)
+        self.weight, self.weight_scale = scaled_nvfp4_quant(weight, self.weight_global_scale)
        self.bias = bias

    def calibrate_x_absmax(self):
@@ -30,7 +30,7 @@ class MMWeightFp4:
    @torch.no_grad()
    def act_quant_fp4(self, x):
-        return scaled_fp4_quant(x, self.input_global_scale)
+        return scaled_nvfp4_quant(x, self.input_global_scale)

def test_speed(m, k, n):
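
calibrate_x_absmax suggests that input_global_scale is derived from a calibrated activation abs-max, mirroring load_fp4_weight's 2688.0 / absmax formula, with alpha folding both global scales exactly as in the test above. A sketch of that derivation (derive_scales is a hypothetical helper, not part of the class):

def derive_scales(x_absmax: torch.Tensor, weight_global_scale: torch.Tensor):
    # Mirrors load_fp4_weight: global_scale = 2688.0 / abs-max of the tensor.
    input_global_scale = (2688.0 / x_absmax).to(torch.float32)
    # Same alpha convention as test_nvfp4_gemm: 1 / (a_scale * b_scale).
    alpha = 1.0 / (input_global_scale * weight_global_scale)
    return input_global_scale, alpha
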
......
import torch
from lightx2v_kernel.gemm import scaled_fp4_quant, cutlass_scaled_fp4_mm
import time
from test_bench2 import MMWeightFp4
......
import torch
-from lightx2v_kernel.gemm import cutlass_scaled_fp4_mm
+from lightx2v_kernel.gemm import cutlass_scaled_nvfp4_mm
"""
@@ -16,7 +16,7 @@ bias = None
def test_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha, bias):
-    output_tensor = cutlass_scaled_fp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
+    output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, weight, input_tensor_scale, weight_scale, alpha=alpha, bias=bias)
    return output_tensor
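
test_mm is the unit that the speed test times. A plausible CUDA-event harness for it, shown for context rather than taken from this file:

import torch

def time_kernel(fn, *args, num_warmup=10, num_runs=100):
    # Warm up so one-time compilation and cache effects don't skew the result.
    for _ in range(num_warmup):
        fn(*args)
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(num_runs):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / num_runs  # average milliseconds per call
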
......
import torch
-from lightx2v_kernel.gemm import scaled_fp4_quant
+from lightx2v_kernel.gemm import scaled_nvfp4_quant

input_global_scale = torch.tensor(808.0, dtype=torch.float32).cuda()

def quantize_fp4(x):
-    return scaled_fp4_quant(x, input_global_scale)
+    return scaled_nvfp4_quant(x, input_global_scale)

def test_memory_bandwidth(func, x, num_warmup=10, num_runs=100):
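
test_memory_bandwidth's body is elided here; for a quantization kernel the interesting figure is bytes moved per second. One plausible byte count, assuming the kernel reads the input once and writes packed FP4 (two values per byte) plus one FP8 scale per 16-element block:

def quant_bytes_moved(x: torch.Tensor, block: int = 16) -> int:
    # Assumed traffic model for scaled_nvfp4_quant, not taken from this diff.
    n = x.numel()
    return n * x.element_size() + n // 2 + n // block

# Bandwidth in GB/s given an average per-call time in milliseconds:
# gbps = quant_bytes_moved(x) / (ms * 1e-3) / 1e9
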
......