Commit 9ab6cd98 authored by yuguo
parents 782f6092 84e8ce2f
@@ -5,10 +5,9 @@ import transformer_engine_torch as tex
 import warnings
 try:
     import lightop
-    enable_lightop = True
 except ImportError:
-    enable_lightop = False
+    pass
-from transformer_engine.pytorch import get_device_compute_capability
+from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
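Note: the pattern change in this hunk recurs in every file below. Before, each module kept its own `enable_lightop` flag next to the guarded import; now the import stays guarded but the decision logic moves into the shared `use_lightop_w8a8` helper added to utils in the last hunk of this commit. A minimal sketch of the new call-site shape, using only names introduced by this commit:

# Sketch of the new call-site pattern:
try:
    import lightop                      # optional; only touched when the predicate says so
except ImportError:
    pass
from transformer_engine.pytorch.utils import use_lightop_w8a8

# utils keeps its own enable_lightop flag, so availability, compute
# capability, and block size are all checked in one place.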
@@ -202,15 +201,15 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
+    if use_lightop_w8a8([block_len, block_len]):
+        y = lightop.gemm_w8a8_asm(
+            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype, 'TN'
+        )
+    else:
         y, _ = w8a8_block_int8_matmul(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
             output_dtype=out_dtype
         )
-    else:
-        y = lightop.gemm_w8a8_asm(
-            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype
-        )
     # print("int8 gemm output: ", y)
     # print("int8 gemm output shape: ", y.shape)
@@ -385,15 +384,15 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
+    if use_lightop_w8a8([block_len, block_len]):
+        y = lightop.gemm_w8a8_asm(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype, 'TN'
+        )
+    else:
         y, _ = w8a8_block_int8_matmul(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
             output_dtype=dx_dtype
         )
-    else:
-        y = lightop.gemm_w8a8_asm(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype
-        )
     # print("int8 gemm dx: ", y)
@@ -569,16 +568,16 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
-        y, _ = w8a8_block_int8_matmul_wgrad(
+    if use_lightop_w8a8([block_len, block_len]):
+        y = lightop.gemm_w8a8_wgrad_asm(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len],
-            output_dtype=dw_dtype
+            accumulate, [block_len, block_len], dw_dtype, 'TN'
         )
     else:
-        y = lightop.gemm_w8a8_wgrad_asm(
+        y, _ = w8a8_block_int8_matmul_wgrad(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len], dw_dtype
+            accumulate, [block_len, block_len],
+            output_dtype=dw_dtype
         )
     # print("int8 gemm dw: ",y)
...
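On the wgrad path both kernels additionally take an initial output tensor plus an `accumulate` flag; the test passes `dw.clone()` rather than `dw` so accumulation writes into a copy and the original reference `dw` stays intact for the later comparison. A minimal sketch of that convention, reusing the signatures shown above:

# When accumulate=True the kernel roughly computes y = dw_init + wgrad_gemm(dout, x),
# so the initial buffer must be a clone to keep the reference dw unmodified.
dw_init = dw.clone() if accumulate else None
if use_lightop_w8a8([block_len, block_len]):
    y = lightop.gemm_w8a8_wgrad_asm(
        qdout_data, qx_data, ref_scales_dout, ref_scales_x,
        dw_init, accumulate, [block_len, block_len], dw_dtype, 'TN'
    )
else:
    y, _ = w8a8_block_int8_matmul_wgrad(
        qdout_data, qx_data, ref_scales_dout, ref_scales_x,
        dw_init, accumulate, [block_len, block_len], output_dtype=dw_dtype
    )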
@@ -8,12 +8,10 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
-import warnings
 try:
     import lightop
-    enable_lightop = True
 except ImportError:
-    enable_lightop = False
+    pass
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -23,6 +21,7 @@ from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ...debug.pytorch.debug_quantization import DebugQuantizer
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
+from transformer_engine.pytorch.utils import use_lightop_w8a8
 from transformer_engine.pytorch.triton.per_token_group_quant import (per_token_quant_fp8_to_int8,
                                                                      per_token_quant_fp8_to_int8_v2,
                                                                      per_token_quant_fp8_to_int8_opt,
@@ -92,15 +91,15 @@ def general_gemm(
         )
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
+        if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
+            y = lightop.gemm_w8a8_asm(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
+            )
+        else:
             y, _ = w8a8_block_int8_matmul(
                 qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
-        else:
-            y = lightop.gemm_w8a8_asm(
-                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
-            )
         return y, None, None, None
     elif layout == "NN":
@@ -112,15 +111,15 @@ def general_gemm(
         )
         ref_scales_dout = B._rowwise_scale_inv
         ref_scales_w = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
+        if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
+            y = lightop.gemm_w8a8_asm(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
+            )
+        else:
             y, _ = w8a8_block_int8_matmul(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
-        else:
-            y = lightop.gemm_w8a8_asm(
-                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
-            )
         return y, None, None, None
     elif layout == "NT":
@@ -132,15 +131,15 @@ def general_gemm(
         )
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or blockwise_fp8_block_len != 128 or not enable_lightop:
+        if use_lightop_w8a8([blockwise_fp8_block_len, blockwise_fp8_block_len]):
+            out = lightop.gemm_w8a8_wgrad_asm(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN'
+            )
+        else:
             out, _ = w8a8_block_int8_matmul_wgrad(
                 qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
-        else:
-            out = lightop.gemm_w8a8_wgrad_asm(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
-            )
         return out, None, None, None
     else:
...
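For orientation, the three `general_gemm` layouts touched above appear to map onto the three GEMMs of a linear layer, judging by the operands and scale orientations each branch picks (this mapping is inferred from the hunks, not stated in the commit):

# Inferred operand roles per layout in general_gemm:
#   "TN": forward   y  = x @ w.T    rowwise x and w scales
#   "NN": dgrad     dx = dy @ w     rowwise dy, columnwise w scales
#   "NT": wgrad     dw = dy.T @ x   columnwise dy and x scales; supports
#                                   accumulation into an existing `out`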
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import warnings
 try:
     import lightop
-    enable_lightop = True
 except ImportError:
-    enable_lightop = False
+    pass
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import use_lightop_w8a8
 logger = logging.getLogger(__name__)
@@ -615,24 +613,23 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
+    if use_lightop_w8a8(block_size):
+        output = lightop.gemm_w8a8_asm(
+            q_input, weight, x_scale, weight_scale, block_size, out_dtype, 'TN'
+        )
+    else:
         output,config = w8a8_block_int8_matmul(
             q_input, weight, x_scale, weight_scale, block_size,
             output_dtype=out_dtype,
             best_config=best_config
         )
-    else:
-        output = lightop.gemm_w8a8_asm(
-            q_input, weight, x_scale, weight_scale, block_size, out_dtype
-        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
-        print("triton accuracy check FAILED!!!")
+        print("w8a8_block_int8 accuracy check FAILED!!!")
     else:
-        print("triton accuracy check passed")
+        print("w8a8_block_int8 accuracy check passed")
     # unit test end
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if not use_lightop_w8a8(block_size):
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
             for it in range(1000):
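The unit test compares each kernel against `native_w8a8_block_int8_matmul` with `torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2)`, i.e. elementwise `|a - b| <= atol + rtol * |b|`, loose enough for int8 round-trip error. A pure-PyTorch sketch of the semantics being checked, assuming the `_scale_inv`-style scales are dequantization multipliers, per-token-group for the activation and per-block for the weight (an illustration, not the repository's actual reference implementation):

import torch

def w8a8_block_matmul_reference(q_input, weight, x_scale, weight_scale,
                                block_size, out_dtype=torch.bfloat16):
    """Illustrative reference: dequantize blockwise, then one fp32 GEMM."""
    blk_n, blk_k = block_size
    M, K = q_input.shape          # (M, K) int8, scales (M, ceil(K/blk_k))
    N = weight.shape[0]           # (N, K) int8, scales (ceil(N/blk_n), ceil(K/blk_k))
    # Expand per-block scales to elementwise multipliers, then dequantize.
    xs = x_scale.repeat_interleave(blk_k, dim=1)[:, :K]
    ws = weight_scale.repeat_interleave(blk_n, dim=0)[:N, :]
    ws = ws.repeat_interleave(blk_k, dim=1)[:, :K]
    x_deq = q_input.float() * xs.float()
    w_deq = weight.float() * ws.float()
    return (x_deq @ w_deq.t()).to(out_dtype)   # NT layout: x @ w.T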
@@ -811,7 +808,7 @@ def main():
             best_config = []
             apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        if not use_lightop_w8a8(block_size):
             output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
             cost_times.append(elapsed_time)
@@ -831,7 +828,7 @@ def main():
         else:
             apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if not use_lightop_w8a8(block_size):
         # Create a DataFrame containing these three lists
         df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'linear-layer gemm quant op time': cost_times,'GPU op time':gpu_costtimes,
                            'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
...
@@ -11,13 +11,11 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import warnings
 try:
     import lightop
-    enable_lightop = True
 except ImportError:
-    enable_lightop = False
+    pass
-from transformer_engine.pytorch.utils import get_device_compute_capability
+from transformer_engine.pytorch.utils import use_lightop_w8a8
 logger = logging.getLogger(__name__)
 device_name=torch.cuda.get_device_properties('cuda').name.replace(" ","_")
@@ -469,16 +467,16 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
 ):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
+        if use_lightop_w8a8(block_size):
+            C_list[i] = lightop.gemm_w8a8_wgrad_asm(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype, 'TN'
+            )
+        else:
             C_list[i], config = w8a8_block_int8_matmul_wgrad(
                 A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
                 output_dtype=output_dtype,
                 best_config=best_config
             )
-        else:
-            C_list[i] = lightop.gemm_w8a8_wgrad_asm(
-                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype
-            )
     return C_list

 def w8a8_block_int8_matmul_wgrad_batched(
@@ -667,25 +665,24 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     N, K = weight.shape
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
-    print(f"zhenggf: transposed inputs passed to the triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
+    if use_lightop_w8a8(block_size):
+        output = lightop.gemm_w8a8_wgrad_asm(
+            q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype, 'TN'
+        )
+    else:
+        print(f"zhenggf: transposed inputs passed to the triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
         output,config = w8a8_block_int8_matmul_wgrad(
             q_input, weight, x_scale, weight_scale, output, False, block_size,
             output_dtype=out_dtype,
             best_config=best_config
         )
-    else:
-        output = lightop.gemm_w8a8_wgrad_asm(
-            q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype
-        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check FAILED!!!")
     else:
         print("triton accuracy check passed")
     # unit test end
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if not use_lightop_w8a8(block_size):
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
             for it in range(1000):
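The benchmark here captures 1000 back-to-back launches into a single CUDA graph, so replay timing reflects device work rather than Python launch overhead. A generic sketch of that capture-and-replay timing pattern (the `run_gemm` callable is a stand-in, and a few warm-up iterations before capture are assumed):

import torch

def time_with_cuda_graph(run_gemm, iters=1000):
    """Capture `iters` launches into one graph, then time a single replay."""
    for _ in range(3):            # warm up / trigger autotuning before capture
        run_gemm()
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):     # records the kernels; nothing executes yet
        for _ in range(iters):
            run_gemm()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    g.replay()                    # replays all captured kernels in one go
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters   # average ms per iteration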
@@ -865,7 +862,7 @@ def main():
             best_config = []
             apply_w8a8_block_int8_linear_batched_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        if not use_lightop_w8a8(block_size):
             output,elapsed_time,gpu_costtime,config=apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
             cost_times.append(elapsed_time)
@@ -885,7 +882,7 @@ def main():
         else:
             apply_w8a8_block_int8_linear_helper(m=m,n=n_list[i],k=k_list[i],block_size=block_size,out_dtype=out_dtype,best_config=best_config)
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if not use_lightop_w8a8(block_size):
         # Create a DataFrame containing these three lists
         df = pd.DataFrame({'m':_m,'n':_n,'k':_k,'linear-layer gemm quant op time': cost_times,'GPU op time':gpu_costtimes,
                            'BLOCK_SIZE_M':_configs_block_m,'BLOCK_SIZE_N':_configs_block_n,'BLOCK_SIZE_K':_configs_block_k,
...
@@ -10,7 +10,12 @@ import os
 from typing import Any, Callable, List, Optional, Sequence, Tuple, Union
 import numpy as np
 import torch
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
 from torch.utils.cpp_extension import IS_HIP_EXTENSION
@@ -455,6 +460,19 @@ if IS_HIP_EXTENSION:
         import re
         return (re.search('BW', torch.cuda.get_device_name(torch.cuda.current_device())) is not None)

+def use_lightop_w8a8(block_size: List[int]) -> bool:
+    """Check whether to use lightop for w8a8 GEMMs."""
+    # Hard-disabled for now: lightop is not ready yet, so always fall back.
+    return False
+    if enable_lightop:
+        return get_device_compute_capability() >= (9, 3) and block_size[1] == 128
+    if get_device_compute_capability() >= (9, 3) and block_size[1] == 128:
+        warnings.warn(
+            "Lightop is not available. Using default implementation for w8a8."
+        )
+    return False
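Note that as committed the helper is hard-disabled: the early `return False` short-circuits everything after it, so every call site in this commit currently takes the Triton fallback; deleting that one line re-enables the real check once lightop is ready. A hedged usage sketch:

from transformer_engine.pytorch.utils import use_lightop_w8a8

block_size = [128, 128]
if use_lightop_w8a8(block_size):
    # Reachable only after the early `return False` kill switch is removed:
    # requires compute capability >= (9, 3), 128-wide blocks, and an
    # importable lightop package.
    pass
else:
    pass  # Triton w8a8 fallback path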
 def is_bf16_compatible() -> None:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit
     check on device compute capability to enforce sm_80 or higher.
...