Commit 3f800f01 authored by wenjh

Enable lightop w8a8


Signed-off-by: wenjh <wenjh@sugon.com>
parent 00fcd784
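
Editor's note: this commit consolidates the ROCm int8-simulation GEMM dispatch. The lightop / Triton selection logic that was previously duplicated across the blockwise-GEMM tests, general_gemm(), and the Triton tuning scripts is replaced by a single helper, w8a8_int8_general_gemm, in transformer_engine/pytorch/cpp_extensions/gemm.py, which picks either the lightop assembly kernels or the Triton w8a8_block_int8_matmul* fallbacks per layout ("TN", "NN", "NT").

For orientation only, the sketch below shows the arithmetic these W8A8 block-scaled kernels compute. It is a plain-PyTorch reference assuming per-(token, 128-column) activation scales and 128x128 weight-block scales (block length 128, matching blockwise_fp8_block_len here); it is not the commit's kernels, and the function name and shapes are illustrative.

import torch

def reference_w8a8_block_int8_matmul(qx, qw, x_scale, w_scale, block_len=128, out_dtype=torch.bfloat16):
    # qx: [M, K] int8 activations, x_scale: [M, K // block_len] per-(token, block) scales.
    # qw: [N, K] int8 weights, w_scale: [N // block_len, K // block_len] per-block scales.
    # Returns y = (qx * x_scale) @ (qw * w_scale).T cast to out_dtype.
    M, K = qx.shape
    N, _ = qw.shape
    # Expand the block-granular scales back to element granularity.
    xs = x_scale.repeat_interleave(block_len, dim=1)[:, :K].to(torch.float32)
    ws = w_scale.repeat_interleave(block_len, dim=0)[:N, :].repeat_interleave(block_len, dim=1)[:, :K].to(torch.float32)
    x = qx.to(torch.float32) * xs
    w = qw.to(torch.float32) * ws
    return (x @ w.t()).to(out_dtype)

# Tiny smoke test on random data (the reference math runs fine on CPU).
M, N, K, B = 4, 256, 384, 128
qx = torch.randint(-128, 127, (M, K), dtype=torch.int8)
qw = torch.randint(-128, 127, (N, K), dtype=torch.int8)
x_scale = torch.rand(M, K // B) * 0.01
w_scale = torch.rand(N // B, K // B) * 0.01
y = reference_w8a8_block_int8_matmul(qx, qw, x_scale, w_scale, B)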
@@ -6,10 +6,6 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-try:
-    import lightop
-except ImportError:
-    pass
 from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
@@ -17,12 +13,11 @@ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
 )
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
 from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
+from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm

 def fp8_blockwise_gemm_supported() -> bool:
     supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
@@ -174,38 +169,12 @@ def cublas_gemm_fp8_blockwise_case(
     bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
     if IS_HIP_EXTENSION and int8_simulation_fp8:
-        if use_lightop_w8a8([block_len, block_len]):
-            if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y = lightop.gemm_w8a8_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
-                )
-            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y = lightop.gemm_w8a8_xgrad_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
-                )
-            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
-                y = lightop.gemm_w8a8_wgrad_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None, accumulate, [block_len, block_len], out_dtype, 'TN'
-                )
-            else:
-                assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
-        else:
-            if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
-                    output_dtype=out_dtype
-                )
-            elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
-                    output_dtype=out_dtype
-                )
-            elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
-                y, _ = w8a8_block_int8_matmul_wgrad(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None,
-                    accumulate, [block_len, block_len],
-                    output_dtype=out_dtype
-                )
-            else:
-                assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
+        if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
+        elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "NN", None)
+        elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
+            y = w8a8_int8_general_gemm(qw, qx, out_dtype, accumulate, "NT", out.clone() if accumulate else None)
+        else:
+            assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
     else:
......
@@ -2,20 +2,13 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import warnings
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
     Float8BlockQuantizer,
     Float8BlockwiseQTensor,
 )
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
+from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
 from references.blockwise_quantizer_reference import CuBLASScaleMunger
 from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
@@ -201,15 +194,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_asm(
-            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul(
-            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
-            output_dtype=out_dtype
-        )
+    y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
     # print("int8 gemm output: ", y)
     # print("int8 gemm output shape: ", y.shape)
@@ -384,15 +369,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_asm(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
-            output_dtype=dx_dtype
-        )
+    y = w8a8_int8_general_gemm(qw, qdout, dx_dtype, False, "NN", None)
     # print("int8 gemm dx: ", y)
@@ -568,17 +545,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    if use_lightop_w8a8([block_len, block_len]):
-        y = lightop.gemm_w8a8_wgrad_asm(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len], dw_dtype, 'TN'
-        )
-    else:
-        y, _ = w8a8_block_int8_matmul_wgrad(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len],
-            output_dtype=dw_dtype
-        )
+    y = w8a8_int8_general_gemm(qx, qdout, dw_dtype, accumulate, "NT", dw.clone() if accumulate else None)
     # print("int8 gemm dw: ",y)
......
@@ -8,20 +8,21 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
+import warnings
 try:
     import lightop
+    enable_lightop = True
 except ImportError:
-    pass
+    enable_lightop = False
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
-from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
+from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
 from ..tensor.quantized_tensor import Quantizer
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
-from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
                                                                      per_token_quant_fp8_to_int8_v2,
                                                                      per_token_quant_fp8_to_int8_opt,
@@ -45,6 +46,71 @@ __all__ = [
     "batchgemm",
 ]

+def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, output_dtype=torch.float16):
+    for i in range(len(C_list)):
+        assert C_list[i] is not None
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            C_list[i] = lightop.gemm_w8a8_wgrad_asm(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, output_dtype, "TN"
+            )
+        else:
+            C_list[i], _ = w8a8_block_int8_matmul_wgrad(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len,
+                output_dtype,
+                None
+            )
+    return C_list
+
+def w8a8_int8_general_gemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    out_dtype: Optional[torch.dtype] = None,
+    accumulate: bool = False,
+    layout: str = "TN",
+    out: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if layout == "TN":
+        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with TN layout"
+        assert out is None, "Output tensor not supported in w8a8_general_gemm with TN layout"
+        qx_data = (B._rowwise_data.view(dtype=torch.int8))
+        qw_data = (A._rowwise_data.view(dtype=torch.int8))
+        ref_scales_x = B._rowwise_scale_inv
+        ref_scales_w = A._rowwise_scale_inv
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return y
+    elif layout == "NN":
+        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with NN layout"
+        assert out is None, "Output tensor not supported in w8a8_general_gemm with NN layout"
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            qdout_data = (B._rowwise_data.view(dtype=torch.int8))
+            qw_data = (A._rowwise_data.view(dtype=torch.int8))
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._rowwise_scale_inv
+            y = lightop.gemm_w8a8_xgrad_asm(qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'NN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            qdout_data = (B._rowwise_data.view(dtype=torch.int8))
+            qw_data = (A._columnwise_data.view(dtype=torch.int8))
+            ref_scales_dout = B._rowwise_scale_inv
+            ref_scales_w = A._columnwise_scale_inv
+            y, _ = w8a8_block_int8_matmul(qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return y
+    elif layout == "NT":
+        qdout_data = (B._columnwise_data.view(dtype=torch.int8))
+        qx_data = (A._columnwise_data.view(dtype=torch.int8))
+        ref_scales_dout = B._columnwise_scale_inv
+        ref_scales_x = A._columnwise_scale_inv
+        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
+            out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
+        else:
+            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
+            out, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
+        return out
+    else:
+        raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
+
 def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
     """Validate whether a GEMM scaling factor is consistent with its usage"""
@@ -92,78 +158,6 @@ def general_gemm(
     # + "a valid `ub` communicator object."
     # )

-    if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
-        assert not gelu, "GELU not supported with int8 simulation"
-        assert gelu_in is None, "GELU input not supported with int8 simulation"
-        assert bias is None, "Bias not supported with int8 simulation"
-        assert ub is None, "User buffer not supported with int8 simulation"
-        assert ub_type is None, "User buffer type not supported with int8 simulation"
-        assert extra_output is None, "Extra output not supported with int8 simulation"
-        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
-        if layout == "TN":
-            qx_data = (
-                B._rowwise_data.view(dtype=torch.int8)
-            )
-            qw_data = (
-                A._rowwise_data.view(dtype=torch.int8)
-            )
-            ref_scales_x = B._rowwise_scale_inv
-            ref_scales_w = A._rowwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                y = lightop.gemm_w8a8_asm(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
-                )
-            else:
-                y, _ = w8a8_block_int8_matmul(
-                    qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                    output_dtype=out_dtype
-                )
-            return y, None, None, None
-        elif layout == "NN":
-            qdout_data = (
-                B._rowwise_data.view(dtype=torch.int8)
-            )
-            qw_data = (
-                A._columnwise_data.view(dtype=torch.int8)
-            )
-            ref_scales_dout = B._rowwise_scale_inv
-            ref_scales_w = A._columnwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                y = lightop.gemm_w8a8_asm(
-                    qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
-                )
-            else:
-                y, _ = w8a8_block_int8_matmul(
-                    qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                    output_dtype=out_dtype
-                )
-            return y, None, None, None
-        elif layout == "NT":
-            qdout_data = (
-                B._columnwise_data.view(dtype=torch.int8)
-            )
-            qx_data = (
-                A._columnwise_data.view(dtype=torch.int8)
-            )
-            ref_scales_dout = B._columnwise_scale_inv
-            ref_scales_x = A._columnwise_scale_inv
-            if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
-                out = lightop.gemm_w8a8_wgrad_asm(
-                    qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
-                )
-            else:
-                out, _ = w8a8_block_int8_matmul_wgrad(
-                    qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                    output_dtype=out_dtype
-                )
-            return out, None, None, None
-        else:
-            raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
     if ub is not None:
         assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
         if ub_type == tex.CommOverlapType.RS:
@@ -196,6 +190,25 @@
     ):
         raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")

+    if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
+        assert not gelu, "GELU not supported with int8 simulation"
+        assert gelu_in is None, "GELU input not supported with int8 simulation"
+        assert bias is None, "Bias not supported with int8 simulation"
+        assert ub is None, "User buffer not supported with int8 simulation"
+        assert ub_type is None, "User buffer type not supported with int8 simulation"
+        assert extra_output is None, "Extra output not supported with int8 simulation"
+        assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
+        y = w8a8_int8_general_gemm(
+            A,
+            B,
+            out_dtype,
+            accumulate,
+            layout,
+            out
+        )
+        return y, None, None, None
+
     if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
         assert not gelu, "GELU not supported with int8 simulation"
         assert gelu_in is None, "GELU input not supported with int8 simulation"
@@ -480,8 +493,7 @@ def general_grouped_gemm(
         ref_scales_x = [a._columnwise_scale_inv for a in A]
         out = w8a8_block_int8_matmul_wgrad_batched_native(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-            output_dtype=out_dtype
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, out_dtype
         )
         return out, bias, gelu_input
......
@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8

 logger = logging.getLogger(__name__)
@@ -613,23 +608,18 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    if use_lightop_w8a8(block_size):
-        output = lightop.gemm_w8a8_asm(
-            q_input, weight, x_scale, weight_scale, block_size, out_dtype, 'TN'
-        )
-    else:
-        output,config = w8a8_block_int8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size,
-            output_dtype=out_dtype,
-            best_config=best_config
-        )
+    output,config = w8a8_block_int8_matmul(
+        q_input, weight, x_scale, weight_scale, block_size,
+        output_dtype=out_dtype,
+        best_config=best_config
+    )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("w8a8_block_int8 精度检查不合格!!!")
     else:
         print("w8a8_block_int8 精度检查合格")
     # unit test end
-    if not use_lightop_w8a8(block_size):
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g):
         for it in range(1000):
@@ -653,6 +643,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
         output = output + bias
     return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config

 def get_triton_cache(file_path,n,k,block_n,block_k):
     # Return the contents of the JSON file as a dictionary
     # Read the specified file first; if that path does not exist, fall back to the default path
@@ -808,7 +799,6 @@ def main():
        best_config = []
        apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-        if not use_lightop_w8a8(block_size):
        output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
        cost_times.append(elapsed_time)
@@ -825,10 +815,7 @@ def main():
            _configs_block_num_warps.append(config.num_warps)
            _configs_block_num_stages.append(config.num_stages)
            # _configs_kpack.append(config['kpack'])
-        else:
-            apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-    if not use_lightop_w8a8(block_size):
    # Build a DataFrame containing these lists
    df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
                       'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
@@ -839,6 +826,7 @@ def main():
    print("表格已保存到 gemmoutput.xlsx 文件中。")

 if __name__ == "__main__":
     main()
@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-try:
-    import lightop
-except ImportError:
-    pass
-from transformer_engine.pytorch.utils import use_lightop_w8a8

 logger = logging.getLogger(__name__)
 device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
@@ -461,24 +456,6 @@ def w8a8_block_int8_matmul_wgrad(
     return C,config

-def w8a8_block_int8_matmul_wgrad_batched_native(
-    A_list, B_list, As_list, Bs_list, C_list, accumulate,
-    block_size, output_dtype=torch.float16, best_config=None
-):
-    for i in range(len(C_list)):
-        assert C_list[i] is not None
-        if use_lightop_w8a8(block_size):
-            C_list[i] = lightop.gemm_w8a8_wgrad_asm(
-                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype, 'TN'
-            )
-        else:
-            C_list[i], config = w8a8_block_int8_matmul_wgrad(
-                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
-                output_dtype=output_dtype,
-                best_config=best_config
-            )
-    return C_list
-
 def w8a8_block_int8_matmul_wgrad_batched(
     A_list, B_list, As_list, Bs_list, C_list, accumulate,
     block_size, output_dtype=torch.float16, best_config=None
@@ -665,24 +642,19 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     N, K = weight.shape
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
-    if use_lightop_w8a8(block_size):
-        output = lightop.gemm_w8a8_wgrad_asm(
-            q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype, 'TN'
-        )
-    else:
-        print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-        output,config = w8a8_block_int8_matmul_wgrad(
-            q_input, weight, x_scale, weight_scale, output, False, block_size,
-            output_dtype=out_dtype,
-            best_config=best_config
-        )
+    print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
+    output,config = w8a8_block_int8_matmul_wgrad(
+        q_input, weight, x_scale, weight_scale, output, False, block_size,
+        output_dtype=out_dtype,
+        best_config=best_config
+    )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton 精度检查不合格!!!")
     else:
         print("triton 精度检查合格")
     # unit test end
-    if not use_lightop_w8a8(block_size):
     g = torch.cuda.CUDAGraph()
     with torch.cuda.graph(g):
         for it in range(1000):
@@ -706,6 +678,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
         output = output + bias
     return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config

 def get_triton_cache(file_path,n,k,block_n,block_k):
     # Return the contents of the JSON file as a dictionary
     # Read the specified file first; if that path does not exist, fall back to the default path
@@ -862,7 +835,6 @@ def main():
        best_config = []
        apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-        if not use_lightop_w8a8(block_size):
        output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
        cost_times.append(elapsed_time)
@@ -879,10 +851,7 @@ def main():
            _configs_block_num_warps.append(config.num_warps)
            _configs_block_num_stages.append(config.num_stages)
            # _configs_kpack.append(config['kpack'])
-        else:
-            apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-    if not use_lightop_w8a8:
    # Build a DataFrame containing these lists
    df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
                       'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
......
@@ -10,12 +10,6 @@ import os
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
-import warnings
-try:
-    import lightop
-    enable_lightop = True
-except ImportError:
-    enable_lightop = False
 import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
@@ -460,18 +454,6 @@ if IS_HIP_EXTENSION:
         import re
         return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

-    def use_lightop_w8a8(block_size: List[int]) -> bool:
-        """Check whether to use lightop for w8a8"""
-        # Just return False because lightop is not ready now.
-        return False
-        if(enable_lightop):
-            return get_device_compute_capability() >= (9, 3) and block_size[1] == 128
-        else:
-            if(get_device_compute_capability() >= (9, 3) and block_size[1] == 128):
-                warnings.warn(
-                    "Lightop is not available. Using default implementation for w8a8."
-                )
-            return False

 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
......