"tests/vscode:/vscode.git/clone" did not exist on "490a5f41ada5788bc6dd94ba54ab024e465e0ec6"
Commit 8113d9e0 authored by yuguo's avatar yuguo
Browse files
parents 93ecbc82 d9847b6d
......@@ -32,6 +32,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
channelwise_dequantize_transB,
channelwise_dequantize_transA_add,
channelwise_dequantize_transA_float_add)
from transformer_engine.pytorch.utils import get_device_compute_capability
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [
......@@ -91,7 +92,7 @@ def general_gemm(
)
ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......@@ -111,7 +112,7 @@ def general_gemm(
)
ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......@@ -131,7 +132,7 @@ def general_gemm(
)
ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......@@ -199,7 +200,7 @@ def general_gemm(
x_int8,
transb,
None,
quantization_params,
None,
TE_DType[torch.int32],
bias,
bias_dtype,
......@@ -225,7 +226,7 @@ def general_gemm(
dy_int8,
transb,
None,
quantization_params,
None,
TE_DType[torch.int32],
bias,
bias_dtype,
......@@ -241,6 +242,7 @@ def general_gemm(
return dx, None, None, None
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
......@@ -250,7 +252,7 @@ def general_gemm(
dy_int8,
transb,
None,
quantization_params,
None,
TE_DType[torch.int32],
bias,
bias_dtype,
......@@ -524,6 +526,7 @@ def general_grouped_gemm(
return out, bias, gelu_input
elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
qdout_data_list = []
qx_data_list = []
scales_dout_list = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment