Commit ca1e98b6 authored by wenjh's avatar wenjh
Browse files

Fix float8 blockwise gemm tests with accumulator


Signed-off-by: wenjh <wenjh@sugon.com>
parent 065160ab
......@@ -56,7 +56,10 @@ def cublas_gemm_fp8_blockwise_case(
if use_bias or use_gelu:
pytest.skip("Bias and GELU not supported in int8 simulation mode on ROCm.")
if not ((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled)):
pytest.skip("Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm.")
pytest.skip("Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm.")
if((not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled) or (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled)):
if accumulate:
pytest.skip("Accumulation not supported in fwd and xgrad block scaling in int8 simulation mode on ROCm.")
if x_dtype == torch.float8_e5m2 and w_dtype == torch.float8_e5m2:
pytest.skip("FP8 GEMM doesn't support both a and b types being torch.float8_e5m2")
if not (is_x_1d_scaled or is_w_1d_scaled):
......@@ -185,7 +188,7 @@ def cublas_gemm_fp8_blockwise_case(
qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None, accumulate, [block_len, block_len], out_dtype, 'TN'
)
else:
assert False, "Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
else:
if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y, _ = w8a8_block_int8_matmul(
......@@ -204,7 +207,7 @@ def cublas_gemm_fp8_blockwise_case(
output_dtype=out_dtype
)
else:
assert False, "Only 1Dx2D, 1Dx1D, and 2Dx1D block scaling supported in int8 simulation mode on ROCm."
assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
else:
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment