Commit 3f800f01 authored by wenjh
Browse files

Enable lightop w8a8


Signed-off-by: wenjh <wenjh@sugon.com>
parent 00fcd784
......@@ -6,10 +6,6 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
try:
import lightop
except ImportError:
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len, int8_simulation_fp8)
......@@ -17,12 +13,11 @@ from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
def fp8_blockwise_gemm_supported() -> bool:
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
......@@ -174,40 +169,14 @@ def cublas_gemm_fp8_blockwise_case(
bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype]
if IS_HIP_EXTENSION and int8_simulation_fp8:
if use_lightop_w8a8([block_len, block_len]):
if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
)
elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y = lightop.gemm_w8a8_xgrad_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
)
elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
y = lightop.gemm_w8a8_wgrad_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None, accumulate, [block_len, block_len], out_dtype, 'TN'
)
else:
assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "NN", None)
elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
y = w8a8_int8_general_gemm(qw, qx, out_dtype, accumulate, "NT", out.clone() if accumulate else None)
else:
if(not x_columnwise and not w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
elif (not x_columnwise and w_columnwise and is_x_1d_scaled and not is_w_1d_scaled):
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
elif (x_columnwise and w_columnwise and is_x_1d_scaled and is_w_1d_scaled):
y, _ = w8a8_block_int8_matmul_wgrad(
qx_data, qw_data, ref_scales_x, ref_scales_w, out.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=out_dtype
)
else:
assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
assert False, "Only fwd, xgrad, and wgrad block scaling supported in int8 simulation mode on ROCm."
else:
# cuBLAS GEMM
# return type is out, bias_grad, gelu_input, extra_output
......
......@@ -2,20 +2,13 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
import warnings
try:
import lightop
except ImportError:
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from transformer_engine.pytorch.cpp_extensions.gemm import w8a8_int8_general_gemm
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm
......@@ -201,15 +194,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
y = w8a8_int8_general_gemm(qw, qx, out_dtype, False, "TN", None)
# print("int8 gemm output: ", y)
# print("int8 gemm output shape: ", y.shape)
......@@ -384,15 +369,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
)
y = w8a8_int8_general_gemm(qw, qdout, dx_dtype, False, "NN", None)
# print("int8 gemm dx: ", y)
......@@ -568,17 +545,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len], dw_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=dw_dtype
)
y = w8a8_int8_general_gemm(qx, qdout, dw_dtype, accumulate, "NT", dw.clone() if accumulate else None)
# print("int8 gemm dw: ",y)
......
......@@ -8,20 +8,21 @@ from typing import Iterable, Optional, Tuple, Union, List
import os
import torch
import transformer_engine_torch as tex
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
pass
enable_lightop = False
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched, w8a8_block_int8_matmul_wgrad_batched_native
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
per_token_quant_fp8_to_int8_v2,
per_token_quant_fp8_to_int8_opt,
......@@ -45,6 +46,71 @@ __all__ = [
"batchgemm",
]
def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, output_dtype=torch.float16):
    """Run the wgrad int8 block-scaled GEMM over a batch, one GEMM per list entry.

    Each (A, B) pair is dispatched either to the lightop assembly kernel
    (device compute capability >= (9, 3) and global block length 128) or to
    the Triton ``w8a8_block_int8_matmul_wgrad`` fallback.  Entries of
    ``C_list`` are the per-GEMM outputs and must be pre-allocated by the
    caller (they may be accumulated into when ``accumulate`` is True).

    Returns ``C_list`` with each entry replaced by its GEMM result.
    """
    # Both kernels take the block size as a [block_n, block_k] pair, matching
    # every other call site (e.g. w8a8_int8_general_gemm); the scalar
    # blockwise_fp8_block_len alone is not a valid block_size argument.
    block_size = [blockwise_fp8_block_len, blockwise_fp8_block_len]
    for i in range(len(C_list)):
        assert C_list[i] is not None
        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
            C_list[i] = lightop.gemm_w8a8_wgrad_asm(
                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype, "TN"
            )
        else:
            # Triton fallback returns (output, config); the config is unused here.
            C_list[i], _ = w8a8_block_int8_matmul_wgrad(
                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
                output_dtype,
                None
            )
    return C_list
def w8a8_int8_general_gemm(
    A: torch.Tensor,
    B: torch.Tensor,
    out_dtype: Optional[torch.dtype] = None,
    accumulate: bool = False,
    layout: str = "TN",
    out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Int8 block-scaled GEMM used to simulate blockwise-FP8 GEMMs on ROCm.

    ``A`` is the weight-like operand and ``B`` the activation/grad-like
    operand.  Despite the ``torch.Tensor`` annotations, both are expected to
    be blockwise-quantized tensor objects exposing ``_rowwise_data`` /
    ``_columnwise_data`` buffers and the matching ``*_scale_inv`` scales
    (Float8BlockwiseQTensor-style) — TODO confirm against callers.

    ``layout`` selects the variant: "TN" (forward), "NN" (dgrad), or
    "NT" (wgrad, optionally accumulating into ``out``).  The lightop
    assembly kernels are used on devices with compute capability >= (9, 3)
    when the global block length is 128; otherwise the Triton blockwise
    int8 kernels serve as a fallback (a warning is emitted).

    Raises:
        ValueError: for any layout other than "TN", "NN", "NT".
    """
    if layout == "TN":
        # Forward: rowwise-quantized activation x rowwise-quantized weight.
        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with TN layout"
        assert out is None, "Output tensor not supported in w8a8_general_gemm with TN layout"
        qx_data = (B._rowwise_data.view(dtype=torch.int8))
        qw_data = (A._rowwise_data.view(dtype=torch.int8))
        ref_scales_x = B._rowwise_scale_inv
        ref_scales_w = A._rowwise_scale_inv
        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
            y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
        else:
            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
            y, _ = w8a8_block_int8_matmul(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype)
        return y
    elif layout == "NN":
        # dgrad: grad-output x weight.  NOTE(review): the lightop xgrad kernel
        # consumes the rowwise weight copy, while the Triton fallback needs the
        # columnwise (transposed) copy — operands are therefore selected per
        # branch.  Confirm this matches the lightop kernel's layout contract.
        assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with NN layout"
        assert out is None, "Output tensor not supported in w8a8_general_gemm with NN layout"
        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
            qdout_data = (B._rowwise_data.view(dtype=torch.int8))
            qw_data = (A._rowwise_data.view(dtype=torch.int8))
            ref_scales_dout = B._rowwise_scale_inv
            ref_scales_w = A._rowwise_scale_inv
            y = lightop.gemm_w8a8_xgrad_asm(qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'NN')
        else:
            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
            qdout_data = (B._rowwise_data.view(dtype=torch.int8))
            qw_data = (A._columnwise_data.view(dtype=torch.int8))
            ref_scales_dout = B._rowwise_scale_inv
            ref_scales_w = A._columnwise_scale_inv
            y, _ = w8a8_block_int8_matmul(qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],out_dtype)
        return y
    elif layout == "NT":
        # wgrad: columnwise grad-output x columnwise activation; `out` is both
        # the optional accumulation input and the returned result.
        qdout_data = (B._columnwise_data.view(dtype=torch.int8))
        qx_data = (A._columnwise_data.view(dtype=torch.int8))
        ref_scales_dout = B._columnwise_scale_inv
        ref_scales_x = A._columnwise_scale_inv
        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128:
            out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
        else:
            warnings.warn("Lightop is not available. Using default implementation for w8a8.")
            out, _ = w8a8_block_int8_matmul_wgrad(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],out_dtype)
        return out
    else:
        raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
"""Validate whether a GEMM scaling factor is consistent with its usage"""
......@@ -92,78 +158,6 @@ def general_gemm(
# + "a valid `ub` communicator object."
# )
if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
if layout == "TN":
qx_data = (
B._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
A._rowwise_data.view(dtype=torch.int8)
)
ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return y, None, None, None
elif layout == "NN":
qdout_data = (
B._rowwise_data.view(dtype=torch.int8)
)
qw_data = (
A._columnwise_data.view(dtype=torch.int8)
)
ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return y, None, None, None
elif layout == "NT":
qdout_data = (
B._columnwise_data.view(dtype=torch.int8)
)
qx_data = (
A._columnwise_data.view(dtype=torch.int8)
)
ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
out = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
return out, None, None, None
else:
raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
if ub is not None:
assert ub_type is not None, "Comm+GEMM overlap requires a valid `comm_type` argument."
if ub_type == tex.CommOverlapType.RS:
......@@ -195,6 +189,25 @@ def general_gemm(
or B._data_format != tex.Float8BlockScaleTensorFormat.GEMM_READY
):
raise RuntimeError("GEMM with Float8BlockwiseQTensor requires GEMM_READY format")
if int8_simulation_fp8 and (isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase)):
assert not gelu, "GELU not supported with int8 simulation"
assert gelu_in is None, "GELU input not supported with int8 simulation"
assert bias is None, "Bias not supported with int8 simulation"
assert ub is None, "User buffer not supported with int8 simulation"
assert ub_type is None, "User buffer type not supported with int8 simulation"
assert extra_output is None, "Extra output not supported with int8 simulation"
assert not bulk_overlap, "Bulk overlap not supported with int8 simulation"
y = w8a8_int8_general_gemm(
A,
B,
out_dtype,
accumulate,
layout,
out
)
return y, None, None, None
if int8_simulation_fp8 and (isinstance(A, Float8TensorBase) or isinstance(B, Float8TensorBase)) and int8_simulation_fp8_tensorwise:
assert not gelu, "GELU not supported with int8 simulation"
......@@ -480,8 +493,7 @@ def general_grouped_gemm(
ref_scales_x = [a._columnwise_scale_inv for a in A]
out = w8a8_block_int8_matmul_wgrad_batched_native(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, out_dtype
)
return out, bias, gelu_input
......
......@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
try:
import lightop
except ImportError:
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
logger = logging.getLogger(__name__)
......@@ -613,45 +608,41 @@ def apply_w8a8_block_int8_linear_helper(m: int,
torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
x_scale = x_scale.permute(1, 0).contiguous()
if use_lightop_w8a8(block_size):
output = lightop.gemm_w8a8_asm(
q_input, weight, x_scale, weight_scale, block_size, out_dtype, 'TN'
)
else:
output,config = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
output,config = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("w8a8_block_int8 精度检查不合格!!!")
else:
print("w8a8_block_int8 精度检查合格")
# unit test end
if not use_lightop_w8a8(block_size):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time() # 开始计时
g.replay()
torch.cuda.synchronize()
end_time_ = time.time() # 结束计时
elapsed_time = round((end_time_ - start_time_) *1000 ,7)# 计算耗时
print("_time:{} us\n".format(elapsed_time))
quantiles = [0.5, 0.2, 0.8]
gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time() # 开始计时
g.replay()
torch.cuda.synchronize()
end_time_ = time.time() # 结束计时
elapsed_time = round((end_time_ - start_time_) *1000 ,7)# 计算耗时
print("_time:{} us\n".format(elapsed_time))
quantiles = [0.5, 0.2, 0.8]
gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
def get_triton_cache(file_path,n,k,block_n,block_k):
#会将所报错的json文件以字典的形式return出来
......@@ -808,36 +799,33 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if not use_lightop_w8a8(block_size):
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# _configs_kpack.append(config['kpack'])
else:
apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# _configs_kpack.append(config['kpack'])
if not use_lightop_w8a8(block_size):
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
'GROUP_SIZE_M':_configs_block_group_m,'num_warps':_configs_block_num_warps,'num_stages':_configs_block_num_stages,#'kpack':_configs_kpack
})
# 将 DataFrame 写入 Excel 文件
df.to_excel('gemmoutput.xlsx', index=False)
print("表格已保存到 gemmoutput.xlsx 文件中。")
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
'GROUP_SIZE_M':_configs_block_group_m,'num_warps':_configs_block_num_warps,'num_stages':_configs_block_num_stages,#'kpack':_configs_kpack
})
# 将 DataFrame 写入 Excel 文件
df.to_excel('gemmoutput.xlsx', index=False)
print("表格已保存到 gemmoutput.xlsx 文件中。")
if __name__ == "__main__":
main()
......
......@@ -11,11 +11,6 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
try:
import lightop
except ImportError:
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
logger = logging.getLogger(__name__)
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
......@@ -461,24 +456,6 @@ def w8a8_block_int8_matmul_wgrad(
return C,config
def w8a8_block_int8_matmul_wgrad_batched_native(
    A_list, B_list, As_list, Bs_list, C_list, accumulate,
    block_size, output_dtype=torch.float16, best_config=None
):
    """Apply the wgrad int8 block-scaled matmul to every batch entry in turn.

    Each pre-allocated entry of C_list receives its GEMM result; the updated
    C_list is returned.
    """
    for idx, c_out in enumerate(C_list):
        assert c_out is not None
        if use_lightop_w8a8(block_size):
            # Hand-tuned lightop assembly kernel path.
            C_list[idx] = lightop.gemm_w8a8_wgrad_asm(
                A_list[idx], B_list[idx], As_list[idx], Bs_list[idx], c_out, accumulate, block_size, output_dtype, 'TN'
            )
        else:
            # Triton fallback; the tuned config it returns is not propagated.
            result, config = w8a8_block_int8_matmul_wgrad(
                A_list[idx], B_list[idx], As_list[idx], Bs_list[idx], c_out, accumulate, block_size,
                output_dtype=output_dtype,
                best_config=best_config
            )
            C_list[idx] = result
    return C_list
def w8a8_block_int8_matmul_wgrad_batched(
A_list, B_list, As_list, Bs_list, C_list, accumulate,
block_size, output_dtype=torch.float16, best_config=None
......@@ -665,46 +642,42 @@ def apply_w8a8_block_int8_linear_helper(m: int,
N, K = weight.shape
C_shape = q_input.shape[:-1] + (N,)
output = q_input.new_empty(C_shape, dtype=out_dtype)
if use_lightop_w8a8(block_size):
output = lightop.gemm_w8a8_wgrad_asm(
q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype, 'TN'
)
else:
print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
output,config = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
best_config=best_config
)
print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
output,config = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("triton 精度检查不合格!!!")
else:
print("triton 精度检查合格")
# unit test end
if not use_lightop_w8a8(block_size):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time() # 开始计时
g.replay()
torch.cuda.synchronize()
end_time_ = time.time() # 结束计时
elapsed_time = round((end_time_ - start_time_) *1000 ,7)# 计算耗时
print("_time:{} us\n".format(elapsed_time))
quantiles = [0.5, 0.2, 0.8]
gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
output,_ = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
best_config=best_config
)
torch.cuda.synchronize()
start_time_ = time.time() # 开始计时
g.replay()
torch.cuda.synchronize()
end_time_ = time.time() # 结束计时
elapsed_time = round((end_time_ - start_time_) *1000 ,7)# 计算耗时
print("_time:{} us\n".format(elapsed_time))
quantiles = [0.5, 0.2, 0.8]
gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
if bias is not None:
output = output + bias
return output.to(dtype=out_dtype),elapsed_time,gpu_costtime,config
def get_triton_cache(file_path,n,k,block_n,block_k):
#会将所报错的json文件以字典的形式return出来
......@@ -862,36 +835,32 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if not use_lightop_w8a8(block_size):
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# _configs_kpack.append(config['kpack'])
else:
apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
gpu_costtimes.append(gpu_costtime)
_n.append(n_list[i])
_k.append(k_list[i])
_m.append(m)
print(f"zhenggf, {config}")
print(f"zhenggf, {config.kwargs}")
_configs_block_m.append(config.kwargs['BLOCK_SIZE_M'])
_configs_block_n.append(config.kwargs['BLOCK_SIZE_N'])
_configs_block_k.append(config.kwargs['BLOCK_SIZE_K'])
_configs_block_group_m.append(config.kwargs['GROUP_SIZE_M'])
_configs_block_num_warps.append(config.num_warps)
_configs_block_num_stages.append(config.num_stages)
# _configs_kpack.append(config['kpack'])
if not use_lightop_w8a8:
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
'GROUP_SIZE_M':_configs_block_group_m,'num_warps':_configs_block_num_warps,'num_stages':_configs_block_num_stages,#'kpack':_configs_kpack
})
# 将 DataFrame 写入 Excel 文件
df.to_excel('gemmoutput.xlsx', index=False)
print("表格已保存到 gemmoutput.xlsx 文件中。")
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
'GROUP_SIZE_M':_configs_block_group_m,'num_warps':_configs_block_num_warps,'num_stages':_configs_block_num_stages,#'kpack':_configs_kpack
})
# 将 DataFrame 写入 Excel 文件
df.to_excel('gemmoutput.xlsx', index=False)
print("表格已保存到 gemmoutput.xlsx 文件中。")
if __name__ == "__main__":
main()
......
......@@ -10,12 +10,6 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -460,18 +454,6 @@ if IS_HIP_EXTENSION:
import re
return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def use_lightop_w8a8(block_size: List[int]) -> bool:
    """Check whether to use lightop for w8a8.

    Lightop is not ready yet, so this path is disabled unconditionally and
    the function always returns ``False`` regardless of ``block_size``.

    When re-enabling, gate on ``enable_lightop``, device compute capability
    >= (9, 3), and ``block_size[1] == 128`` (see git history for the previous
    template), and warn when the hardware qualifies but lightop is absent.
    """
    # Everything after this return was unreachable dead code and has been
    # removed; the toggle logic is preserved in the docstring above.
    return False
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment