"vscode:/vscode.git/clone" did not exist on "ac8d36f3e5776e020eb32a08eb3a2f9c60a49344"
Commit 6c9dc19d authored by wenjh

Refine the constraints for using lightop in gemm.py


Signed-off-by: wenjh <wenjh@sugon.com>
parent 59b49b47
@@ -53,7 +53,7 @@ __all__ = [
 def w8a8_block_int8_matmul_wgrad_batched_native(A_list, B_list, As_list, Bs_list, C_list, accumulate, out_dtype=torch.float16):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
+        if enable_lightop and get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             C_list[i] = lightop.gemm_w8a8_wgrad_asm(
                 A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, blockwise_fp8_block_len, out_dtype, "TN"
             )
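The same eligibility check now guards every lightop call site in this file. As a minimal sketch, assuming enable_lightop, blockwise_fp8_block_len, and get_device_compute_capability are module-level state in gemm.py (the stubs below are illustrative, not part of this commit), the refined condition could be factored into a single predicate:

import torch

# Illustrative stand-ins for module-level state in gemm.py (assumptions):
enable_lightop = True            # feature flag checked first after this commit
blockwise_fp8_block_len = 128    # block length required by the lightop kernels

def get_device_compute_capability():
    # gemm.py queries the active device; stubbed here for illustration.
    return (9, 3)

def _can_use_lightop(out_dtype: torch.dtype) -> bool:
    # Hypothetical helper mirroring the refined gate in this commit.
    return (
        enable_lightop
        and get_device_compute_capability() >= (9, 3)
        and blockwise_fp8_block_len == 128
        and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16))
    )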
@@ -80,7 +80,7 @@ def w8a8_int8_general_gemm(
         qw_data = (A._rowwise_data.view(dtype=torch.int8))
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
+        if enable_lightop and get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             y = lightop.gemm_w8a8_asm(qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
         else:
             warnings.warn("Lightop is not available. Using default implementation for w8a8.")
@@ -89,7 +89,7 @@ def w8a8_int8_general_gemm(
     elif layout == "NN":
         assert accumulate is False, "Accumulate not supported in w8a8_general_gemm with NN layout"
         assert out is None, "Output tensor not supported in w8a8_general_gemm with NN layout"
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
+        if enable_lightop and get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             qdout_data = (B._rowwise_data.view(dtype=torch.int8))
             qw_data = (A._rowwise_data.view(dtype=torch.int8))
             ref_scales_dout = B._rowwise_scale_inv
@@ -108,7 +108,7 @@ def w8a8_int8_general_gemm(
         qx_data = (A._columnwise_data.view(dtype=torch.int8))
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        if get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
+        if enable_lightop and get_device_compute_capability() >= (9, 3) and blockwise_fp8_block_len == 128 and ((out_dtype is torch.bfloat16) or (out_dtype is torch.float16)):
             out = lightop.gemm_w8a8_wgrad_asm(qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN')
         else:
             warnings.warn("Lightop is not available. Using default implementation for w8a8.")
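With the flag in place, each call site reads as a guarded dispatch: take the lightop asm kernel when the predicate holds, otherwise warn and fall through to the default path. A hedged sketch of that shape, reusing the _can_use_lightop helper above; default_w8a8_gemm is an illustrative name for the pre-existing fallback, not an API from this diff:

import warnings

def dispatch_w8a8(qx_data, qw_data, ref_scales_x, ref_scales_w, out_dtype):
    if _can_use_lightop(out_dtype):
        # Fast path: int8 operands plus inverse scales go to the asm kernel.
        return lightop.gemm_w8a8_asm(
            qx_data, qw_data, ref_scales_x, ref_scales_w,
            [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype, 'TN',
        )
    warnings.warn("Lightop is not available. Using default implementation for w8a8.")
    # default_w8a8_gemm is hypothetical; the real file falls back to its
    # existing non-lightop implementation here.
    return default_w8a8_gemm(qx_data, qw_data, ref_scales_x, ref_scales_w, out_dtype)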