Commit 1036ccfe authored by wenjh
Browse files

Use lightop to replace w8a8_matmul_extension


Signed-off-by: wenjh <wenjh@sugon.com>
parent 00738a42
......@@ -2,7 +2,7 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
import w8a8_matmul_extension
import lightop
from transformer_engine.pytorch import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
......@@ -203,9 +203,8 @@ def cublas_gemm_fp8_blockwise_case_fw(
output_dtype=out_dtype
)
else:
y = w8a8_matmul_extension.w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype
)
# print("int8 gemm output: ", y)
......@@ -387,9 +386,8 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
output_dtype=dx_dtype
)
else:
y = w8a8_matmul_extension.w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype
)
# print("int8 gemm dx: ", y)
......@@ -573,10 +571,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
output_dtype=dw_dtype
)
else:
y = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
y = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
output_dtype=dw_dtype
accumulate, [block_len, block_len], dw_dtype
)
# print("int8 gemm dw: ",y)
......
......@@ -8,7 +8,7 @@ from typing import Iterable, Optional, Tuple, Union, List
import os
import torch
import transformer_engine_torch as tex
import w8a8_matmul_extension
import lightop
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
......@@ -92,9 +92,8 @@ def general_gemm(
output_dtype=out_dtype
)
else:
y = w8a8_matmul_extension.w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
y = lightop.gemm_w8a8_asm(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return y, None, None, None
......@@ -113,9 +112,8 @@ def general_gemm(
output_dtype=out_dtype
)
else:
y = w8a8_matmul_extension.w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
y = lightop.gemm_w8a8_asm(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return y, None, None, None
......@@ -134,9 +132,8 @@ def general_gemm(
output_dtype=out_dtype
)
else:
out = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
out = lightop.gemm_w8a8_wgrad_asm(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
)
return out, None, None, None
......
......@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import w8a8_matmul_extension
import lightop
from transformer_engine.pytorch.utils import get_device_compute_capability
logger = logging.getLogger(__name__)
......@@ -617,9 +617,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
best_config=best_config
)
else:
output = w8a8_matmul_extension.w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype
output = lightop.gemm_w8a8_asm(
q_input, weight, x_scale, weight_scale, block_size, out_dtype
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
......
......@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import w8a8_matmul_extension
import lightop
from transformer_engine.pytorch.utils import get_device_compute_capability
logger = logging.getLogger(__name__)
......@@ -471,9 +471,8 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
best_config=best_config
)
else:
C_list[i] = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
output_dtype=output_dtype
C_list[i] = lightop.gemm_w8a8_wgrad_asm(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype
)
return C_list
......@@ -671,9 +670,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
best_config=best_config
)
else:
output = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype
output = lightop.gemm_w8a8_wgrad_asm(
q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
......@@ -836,7 +834,7 @@ def main():
block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
out_dtype=torch.bfloat16
out_dtype=torch.float16
_n=[]
_k=[]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment