Commit 8113d9e0 authored by yuguo's avatar yuguo
Browse files
parents 93ecbc82 d9847b6d
...@@ -32,6 +32,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q ...@@ -32,6 +32,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_q
channelwise_dequantize_transB, channelwise_dequantize_transB,
channelwise_dequantize_transA_add, channelwise_dequantize_transA_add,
channelwise_dequantize_transA_float_add) channelwise_dequantize_transA_float_add)
from transformer_engine.pytorch.utils import get_device_compute_capability
int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0"))) int8_simulation_fp8 = bool(int(os.getenv("NVTE_INT8_SIM_FP8", "0")))
__all__ = [ __all__ = [
...@@ -91,7 +92,7 @@ def general_gemm( ...@@ -91,7 +92,7 @@ def general_gemm(
) )
ref_scales_x = B._rowwise_scale_inv ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv ref_scales_w = A._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop: if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul( y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype output_dtype=out_dtype
...@@ -111,7 +112,7 @@ def general_gemm( ...@@ -111,7 +112,7 @@ def general_gemm(
) )
ref_scales_dout = B._rowwise_scale_inv ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv ref_scales_w = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop: if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul( y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype output_dtype=out_dtype
...@@ -131,7 +132,7 @@ def general_gemm( ...@@ -131,7 +132,7 @@ def general_gemm(
) )
ref_scales_dout = B._columnwise_scale_inv ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv ref_scales_x = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop: if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
out, _ = w8a8_block_int8_matmul_wgrad( out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype output_dtype=out_dtype
...@@ -199,7 +200,7 @@ def general_gemm( ...@@ -199,7 +200,7 @@ def general_gemm(
x_int8, x_int8,
transb, transb,
None, None,
quantization_params, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias, bias,
bias_dtype, bias_dtype,
...@@ -225,7 +226,7 @@ def general_gemm( ...@@ -225,7 +226,7 @@ def general_gemm(
dy_int8, dy_int8,
transb, transb,
None, None,
quantization_params, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias, bias,
bias_dtype, bias_dtype,
...@@ -241,6 +242,7 @@ def general_gemm( ...@@ -241,6 +242,7 @@ def general_gemm(
return dx, None, None, None return dx, None, None, None
elif layout == "NT": elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False) dy_int8, dy_scales = per_token_quant_fp8_to_int8_opt(B._data.view(dtype=TE_DType_To_Torch[B._fp8_dtype]), B._scale_inv, False)
x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False) x_int8, x_scales = per_token_quant_fp8_to_int8_opt(A._data.view(dtype=TE_DType_To_Torch[A._fp8_dtype]), A._scale_inv, False)
...@@ -250,7 +252,7 @@ def general_gemm( ...@@ -250,7 +252,7 @@ def general_gemm(
dy_int8, dy_int8,
transb, transb,
None, None,
quantization_params, None,
TE_DType[torch.int32], TE_DType[torch.int32],
bias, bias,
bias_dtype, bias_dtype,
...@@ -524,6 +526,7 @@ def general_grouped_gemm( ...@@ -524,6 +526,7 @@ def general_grouped_gemm(
return out, bias, gelu_input return out, bias, gelu_input
elif layout == "NT": elif layout == "NT":
assert out_dtype is torch.bfloat16 or out_dtype is torch.float32
qdout_data_list = [] qdout_data_list = []
qx_data_list = [] qx_data_list = []
scales_dout_list = [] scales_dout_list = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment