"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "992ba01d4aacdd2a59e8f22d7c04d23a6c020752"
Commit 9ab6cd98 authored by yuguo's avatar yuguo
Browse files
parents 782f6092 84e8ce2f
......@@ -5,10 +5,9 @@ import transformer_engine_torch as tex
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch import get_device_compute_capability
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
......@@ -202,15 +201,15 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
)
else:
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype
)
# print("int8 gemm output: ", y)
# print("int8 gemm output shape: ", y.shape)
......@@ -385,15 +384,15 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
)
else:
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype
)
# print("int8 gemm dx: ", y)
......@@ -569,16 +568,16 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul_wgrad(
if use_lightop_w8a8([block_len, block_len]):
y = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=dw_dtype
accumulate, [block_len, block_len], dw_dtype, 'TN'
)
else:
y = lightop.gemm_w8a8_wgrad_asm(
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len], dw_dtype
accumulate, [block_len, block_len],
output_dtype=dw_dtype
)
# print("int8 gemm dw: ",y)
......
......@@ -8,12 +8,10 @@ from typing import Iterable, Optional, Tuple, Union, List
import os
import torch
import transformer_engine_torch as tex
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
pass
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
......@@ -23,6 +21,7 @@ from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTens
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
from transformer_engine.pytorch.utils import use_lightop_w8a8
from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
per_token_quant_fp8_to_int8_v2,
per_token_quant_fp8_to_int8_opt,
......@@ -92,15 +91,15 @@ def general_gemm(
)
ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
else:
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return y, None, None, None
elif layout == "NN":
......@@ -112,15 +111,15 @@ def general_gemm(
)
ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
else:
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return y, None, None, None
elif layout == "NT":
......@@ -132,15 +131,15 @@ def general_gemm(
)
ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
out = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
)
else:
out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
)
else:
out = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return out, None, None, None
else:
......
......@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch.utils import get_device_compute_capability
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
logger = logging.getLogger(__name__)
......@@ -615,24 +613,23 @@ def apply_w8a8_block_int8_linear_helper(m: int,
torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
x_scale = x_scale.permute(1, 0).contiguous()
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
if use_lightop_w8a8(block_size):
output = lightop.gemm_w8a8_asm(
q_input, weight, x_scale, weight_scale, block_size, out_dtype, 'TN'
)
else:
output,config = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
best_config=best_config
)
else:
output = lightop.gemm_w8a8_asm(
q_input, weight, x_scale, weight_scale, block_size, out_dtype
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("triton 精度检查不合格!!!")
print("w8a8_block_int8 精度检查不合格!!!")
else:
print("triton 精度检查合格")
print("w8a8_block_int8 精度检查合格")
# unit test end
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8(block_size):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
......@@ -811,7 +808,7 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8(block_size):
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
......@@ -831,7 +828,7 @@ def main():
else:
apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8(block_size):
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
......
......@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch.utils import get_device_compute_capability
pass
from transformer_engine.pytorch.utils import use_lightop_w8a8
logger = logging.getLogger(__name__)
device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
......@@ -469,16 +467,16 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
):
for i in range(len(C_list)):
assert C_list[i] is not None
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
if use_lightop_w8a8(block_size):
C_list[i] = lightop.gemm_w8a8_wgrad_asm(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype, 'TN'
)
else:
C_list[i], config = w8a8_block_int8_matmul_wgrad(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
output_dtype=output_dtype,
best_config=best_config
)
else:
C_list[i] = lightop.gemm_w8a8_wgrad_asm(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype
)
return C_list
def w8a8_block_int8_matmul_wgrad_batched(
......@@ -667,25 +665,24 @@ def apply_w8a8_block_int8_linear_helper(m: int,
N, K = weight.shape
C_shape = q_input.shape[:-1] + (N,)
output = q_input.new_empty(C_shape, dtype=out_dtype)
print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
if use_lightop_w8a8(block_size):
output = lightop.gemm_w8a8_wgrad_asm(
q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype, 'TN'
)
else:
print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
output,config = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
best_config=best_config
)
else:
output = lightop.gemm_w8a8_wgrad_asm(
q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("triton 精度检查不合格!!!")
else:
print("triton 精度检查合格")
# unit test end
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8(block_size):
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
for it in range(1000):
......@@ -865,7 +862,7 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8(block_size):
output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
cost_times.append(elapsed_time)
......@@ -885,7 +882,7 @@ def main():
else:
apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if not use_lightop_w8a8:
# 创建一个包含这三个列表的 DataFrame
df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'线性层gemm量化算子耗时': cost_times,'GPU算子耗时':gpu_costtimes,
'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
......
......@@ -10,7 +10,12 @@ import os
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
import numpy as np
import torch
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
import transformer_engine.pytorch.cpp_extensions as ext
from . import torch_version
from torch.utils.cpp_extension import IS_HIP_EXTENSION
......@@ -454,6 +459,19 @@ if IS_HIP_EXTENSION:
"""check whether this machine is BW"""
import re
return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)
def use_lightop_w8a8(block_size: List[int]) -> bool:
"""Check whether to use lightop for w8a8"""
# Just return False because lightop is not ready now.
return False
if(enable_lightop):
return get_device_compute_capability() >= (9, 3) and block_size[1] == 128
else:
if(get_device_compute_capability() >= (9, 3) and block_size[1] == 128):
warnings.warn(
"Lightop is not available. Using default implementation for w8a8."
)
return False
def is_bf16_compatible() -> None:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment