Commit 6841663b authored by zhushuang
Browse files

issue/972 - feat: adjust scaled_mm_int8 python test

parent e1974c6b
......@@ -25,7 +25,6 @@ from enum import Enum, auto
# These are not meant to be imported from other modules
_TEST_CASES_ = [
# x_shape, w_shape, y_shape
((2, 4), (4, 2), (2, 2)),
((128, 512), (512, 1024), (128, 1024)),
((256, 1024), (1024, 2048), (256, 2048)),
((1024, 2048), (2048, 1024), (1024, 1024)),
......@@ -83,12 +82,16 @@ def test(
sync=None,
):
print(
f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
f"Testing scaled_mm_int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
)
M, K = x_shape
N = w_shape[1]
x_packed = TestTensor(
# --- Tensor Descriptor ---
# orig: create a random int8 tensor as the reference data source
# torch: extract the torch view to adjust layout/stride
# final: wrap it back as TestTensor with explicit stride for device execution
x_packed_orig = TestTensor(
(M, K),
None,
InfiniDtype.I8,
......@@ -97,8 +100,18 @@ def test(
randint_low=-128,
randint_high=127,
)
weights = TestTensor(
(K, N),
x_packed_torch = x_packed_orig.torch_tensor()
x_packed = TestTensor(
(M, K),
x_packed_torch.stride(),
InfiniDtype.I8,
device,
mode="manual",
set_tensor=x_packed_torch,
)
weights_orig = TestTensor(
(N, K),
None,
InfiniDtype.I8,
device,
......@@ -106,9 +119,44 @@ def test(
randint_low=-128,
randint_high=127,
)
x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
bias = TestTensor((N,), None, dtype, device, mode="random")
weights_torch = weights_orig.torch_tensor().t()
weights = TestTensor(
(K, N),
weights_torch.stride(),
InfiniDtype.I8,
device,
mode="manual",
set_tensor=weights_torch,
)
x_scale_orig = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
x_scale_torch = x_scale_orig.torch_tensor()
x_scale = TestTensor(
(M,),
x_scale_torch.stride(),
InfiniDtype.F32,
device,
mode="manual",
set_tensor=x_scale_torch,
)
weights_scale_orig = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
weights_scale_torch = weights_scale_orig.torch_tensor()
weights_scale = TestTensor(
(N,),
weights_scale_torch.stride(),
InfiniDtype.F32,
device,
mode="manual",
set_tensor=weights_scale_torch,
)
bias_orig = TestTensor((N,), None, dtype, device, mode="random")
bias_torch = bias_orig.torch_tensor()
bias = TestTensor(
(N,), bias_torch.stride(), dtype, device, mode="manual", set_tensor=bias_torch
)
y = TestTensor(y_shape, None, dtype, device, mode="zeros")
ans = torch_scaled_mm(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment