Commit d9847b6d authored by yuguo

[DCU] fix

parent 736e8f8b
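This commit touches the DCU INT8/FP8 GEMM paths in general_gemm and general_grouped_gemm: it imports get_device_compute_capability, replaces block_len with blockwise_fp8_block_len in the lightop fallback checks, passes None instead of quantization_params to the INT32-output GEMM calls, and adds the missing dy_int8 quantization and qdout_data_list initialization in the NT branches. As a sketch only (the helper below is hypothetical and not part of the repo), the corrected fallback condition amounts to:

```python
# Illustrative sketch only -- `_lightop_path_available` is a hypothetical helper,
# not code from this commit. It restates the guard that the diff corrects:
# the lightop kernel is used only when the device reports compute capability
# (9, 3) or newer, the blockwise-FP8 block length is 128, and lightop is enabled;
# otherwise the Triton w8a8 block-INT8 matmul fallback is taken.
def _lightop_path_available(compute_capability, blockwise_fp8_block_len, enable_lightop):
    return (
        compute_capability >= (9, 3)
        and blockwise_fp8_block_len == 128
        and enable_lightop
    )

# The checks in general_gemm are the negation of this predicate, e.g.:
#   if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
#       y, _ = w8a8_block_int8_matmul(...)
```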
@@ -32,6 +32,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
 channelwise_dequantize_transB,
 channelwise_dequantize_transA_add,
 channelwise_dequantize_transA_float_add)
+from transformer_engine.pytorch.utils import get_device_compute_capability
 int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
 __all__ = [
@@ -91,7 +92,7 @@ def general_gemm(
 )
 ref_scales_x = B._rowwise_scale_inv
 ref_scales_w = A._rowwise_scale_inv
-if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
+if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
 y, _ = w8a8_block_int8_matmul(
 qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
 output_dtype=out_dtype
@@ -111,7 +112,7 @@ def general_gemm(
 )
 ref_scales_dout = B._rowwise_scale_inv
 ref_scales_w = A._columnwise_scale_inv
-if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
+if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
 y, _ = w8a8_block_int8_matmul(
 qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
 output_dtype=out_dtype
@@ -131,7 +132,7 @@ def general_gemm(
 )
 ref_scales_dout = B._columnwise_scale_inv
 ref_scales_x = A._columnwise_scale_inv
-if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
+if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
 out, _ = w8a8_block_int8_matmul_wgrad(
 qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
 output_dtype=out_dtype
@@ -191,7 +192,7 @@ def general_gemm(
 x_int8,
 transb,
 None,
-quantization_params,
+None,
 TE_DType[torch.int32],
 bias,
 bias_dtype,
@@ -217,7 +218,7 @@ def general_gemm(
 dy_int8,
 transb,
 None,
-quantization_params,
+None,
 TE_DType[torch.int32],
 bias,
 bias_dtype,
@@ -233,6 +234,7 @@ def general_gemm(
 return dx, None, None, None
 elif layout == "NT":
 assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
 x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
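In this NT branch both operands are re-quantized per token before the INT8 GEMM. As a hedged sketch only (this is not the Triton kernel from per_token_group_quant, just the usual symmetric per-row scheme it presumably follows):

```python
import torch

# Rough per-token (per-row) symmetric INT8 quantization sketch. This is an
# assumption about what per_token_quant_fp8_to_int8_opt computes, not the
# actual kernel: each row gets its own scale so its max magnitude maps to 127.
def per_token_int8_quant_reference(x: torch.Tensor):
    x = x.float()
    scales = x.abs().amax(dim=-1, keepdim=True).clamp_(min=1e-12) / 127.0
    q = torch.round(x / scales).clamp_(-127, 127).to(torch.int8)
    return q, scales
```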
@@ -242,7 +244,7 @@ def general_gemm(
 dy_int8,
 transb,
 None,
-quantization_params,
+None,
 TE_DType[torch.int32],
 bias,
 bias_dtype,
@@ -516,6 +518,7 @@ def general_grouped_gemm(
 return out, bias, gelu_input
 elif layout == "NT":
 assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
+qdout_data_list = []
 qx_data_list = []
 scales_dout_list = []