Commit 6841663b authored by zhushuang's avatar zhushuang
Browse files

issue/972 - feat: adjust scaled_mm_int8 python test

parent e1974c6b
...@@ -25,7 +25,6 @@ from enum import Enum, auto ...@@ -25,7 +25,6 @@ from enum import Enum, auto
# These are not meant to be imported from other modules # These are not meant to be imported from other modules
_TEST_CASES_ = [ _TEST_CASES_ = [
# x_shape, w_shape, y_shape, alpha, beta # x_shape, w_shape, y_shape, alpha, beta
((2, 4), (4, 2), (2, 2)),
((128, 512), (512, 1024), (128, 1024)), ((128, 512), (512, 1024), (128, 1024)),
((256, 1024), (1024, 2048), (256, 2048)), ((256, 1024), (1024, 2048), (256, 2048)),
((1024, 2048), (2048, 1024), (1024, 1024)), ((1024, 2048), (2048, 1024), (1024, 1024)),
...@@ -83,12 +82,16 @@ def test( ...@@ -83,12 +82,16 @@ def test(
sync=None, sync=None,
): ):
print( print(
f"Testing Linear on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}" f"Testing scaled_mm_int8 on {InfiniDeviceNames[device]} with x_shape:{x_shape}, w_shape:{w_shape}, inplace:{inplace} dtype:{InfiniDtypeNames[dtype]}"
) )
M, K = x_shape M, K = x_shape
N = w_shape[1] N = w_shape[1]
x_packed = TestTensor( # --- Tensor Descriptor ---
# orig: create a random int8 tensor as the reference data source
# torch: extract the torch view to adjust layout/stride
# final: wrap it back as TestTensor with explicit stride for device execution
x_packed_orig = TestTensor(
(M, K), (M, K),
None, None,
InfiniDtype.I8, InfiniDtype.I8,
...@@ -97,8 +100,18 @@ def test( ...@@ -97,8 +100,18 @@ def test(
randint_low=-128, randint_low=-128,
randint_high=127, randint_high=127,
) )
weights = TestTensor( x_packed_torch = x_packed_orig.torch_tensor()
(K, N), x_packed = TestTensor(
(M, K),
x_packed_torch.stride(),
InfiniDtype.I8,
device,
mode="manual",
set_tensor=x_packed_torch,
)
weights_orig = TestTensor(
(N, K),
None, None,
InfiniDtype.I8, InfiniDtype.I8,
device, device,
...@@ -106,9 +119,44 @@ def test( ...@@ -106,9 +119,44 @@ def test(
randint_low=-128, randint_low=-128,
randint_high=127, randint_high=127,
) )
x_scale = TestTensor((M,), None, InfiniDtype.F32, device, mode="random") weights_torch = weights_orig.torch_tensor().t()
weights_scale = TestTensor((N,), None, InfiniDtype.F32, device, mode="random") weights = TestTensor(
bias = TestTensor((N,), None, dtype, device, mode="random") (K, N),
weights_torch.stride(),
InfiniDtype.I8,
device,
mode="manual",
set_tensor=weights_torch,
)
x_scale_orig = TestTensor((M,), None, InfiniDtype.F32, device, mode="random")
x_scale_torch = x_scale_orig.torch_tensor()
x_scale = TestTensor(
(M,),
x_scale_torch.stride(),
InfiniDtype.F32,
device,
mode="manual",
set_tensor=x_scale_torch,
)
weights_scale_orig = TestTensor((N,), None, InfiniDtype.F32, device, mode="random")
weights_scale_torch = weights_scale_orig.torch_tensor()
weights_scale = TestTensor(
(N,),
weights_scale_torch.stride(),
InfiniDtype.F32,
device,
mode="manual",
set_tensor=weights_scale_torch,
)
bias_orig = TestTensor((N,), None, dtype, device, mode="random")
bias_torch = bias_orig.torch_tensor()
bias = TestTensor(
(N,), bias_torch.stride(), dtype, device, mode="manual", set_tensor=bias_torch
)
y = TestTensor(y_shape, None, dtype, device, mode="zeros") y = TestTensor(y_shape, None, dtype, device, mode="zeros")
ans = torch_scaled_mm( ans = torch_scaled_mm(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment