Commit 429226fb authored by wenjh

Merge branch 'develop_v2.4'

parents 3b0a1009 1036ccfe
@@ -2,7 +2,7 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
@@ -203,9 +203,8 @@ def cublas_gemm_fp8_blockwise_case_fw(
             output_dtype=out_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
-            output_dtype=out_dtype
+        y = lightop.gemm_w8a8_asm(
+            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len], out_dtype
         )
     # print("int8 gemm output: ", y)
@@ -387,9 +386,8 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
             output_dtype=dx_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
-            output_dtype=dx_dtype
+        y = lightop.gemm_w8a8_asm(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len], dx_dtype
         )
     # print("int8 gemm dx: ", y)
@@ -573,10 +571,9 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
             output_dtype=dw_dtype
         )
     else:
-        y = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
-            accumulate, [block_len, block_len],
-            output_dtype=dw_dtype
+        y = lightop.gemm_w8a8_wgrad_asm(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
+            accumulate, [block_len, block_len], dw_dtype
         )
     # print("int8 gemm dw: ",y)
...
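
Reviewer note: every call site in this test changes the same way: the keyword-style entry point w8a8_matmul_extension.w8a8_block_int8_matmul(..., output_dtype=...) becomes lightop.gemm_w8a8_asm(...) with the output dtype as a trailing positional argument. A minimal sketch of the migration pattern, assuming int8 operands with per-block scales; the shapes and scale layouts below are illustrative assumptions, not part of this diff:

    # Hypothetical shapes for illustration only; the commit itself only renames the call.
    import torch
    import lightop  # assumed importable after this merge

    M, N, K, block_len = 256, 512, 1024, 128
    qx = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")  # int8 activations
    qw = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")  # int8 weights
    sx = torch.rand(M, K // block_len, device="cuda")                   # activation scales (assumed layout)
    sw = torch.rand(N // block_len, K // block_len, device="cuda")      # weight scales (assumed layout)

    # Old keyword-style call (removed in this commit):
    #   y = w8a8_matmul_extension.w8a8_block_int8_matmul(
    #           qx, qw, sx, sw, [block_len, block_len], output_dtype=torch.bfloat16)
    # New positional-style call:
    y = lightop.gemm_w8a8_asm(qx, qw, sx, sw, [block_len, block_len], torch.bfloat16)
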
@@ -8,7 +8,7 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
-import w8a8_matmul_extension
+import lightop
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -92,9 +92,8 @@ def general_gemm(
                 output_dtype=out_dtype
             )
         else:
-            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                output_dtype=out_dtype
+            y = lightop.gemm_w8a8_asm(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
             )
         return y, None, None, None
@@ -113,9 +112,8 @@ def general_gemm(
                 output_dtype=out_dtype
             )
         else:
-            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
-                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                output_dtype=out_dtype
+            y = lightop.gemm_w8a8_asm(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
             )
         return y, None, None, None
@@ -134,9 +132,8 @@ def general_gemm(
                 output_dtype=out_dtype
             )
         else:
-            out = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
-                output_dtype=out_dtype
+            out = lightop.gemm_w8a8_wgrad_asm(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len], out_dtype
             )
         return out, None, None, None
...
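
Reviewer note: general_gemm keeps its Triton path and swaps in the lightop fast path behind the same if/else; all three branches (forward, dgrad, wgrad) still return the (tensor, None, None, None) tuple the callers expect. A condensed sketch of the dispatch shape visible in these hunks, where use_asm_kernel stands in for whatever predicate the surrounding (elided) code actually tests:

    import lightop  # new fast path, per this commit
    from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul

    def general_gemm_int8_sketch(qa, qb, sa, sb, block_len, out_dtype, use_asm_kernel):
        """Illustrative only: mirrors the branch structure in this hunk."""
        if not use_asm_kernel:
            # Triton reference path, kept from before the merge (keyword dtype).
            y = w8a8_block_int8_matmul(qa, qb, sa, sb, [block_len, block_len],
                                       output_dtype=out_dtype)
        else:
            # New lightop assembly path (positional dtype).
            y = lightop.gemm_w8a8_asm(qa, qb, sa, sb, [block_len, block_len], out_dtype)
        # Callers of general_gemm unpack four values; only the first is used on this path.
        return y, None, None, None
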
@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
@@ -617,9 +617,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
             best_config=best_config
         )
     else:
-        output = w8a8_matmul_extension.w8a8_block_int8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size,
-            output_dtype=out_dtype
+        output = lightop.gemm_w8a8_asm(
+            q_input, weight, x_scale, weight_scale, block_size, out_dtype
        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
...
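
Reviewer note: apply_w8a8_block_int8_linear_helper still validates the kernel output against a dense torch reference with torch.allclose(rtol=1e-2, atol=5e-2); only the kernel under test changes. A sketch of what such a blockwise-int8 reference can look like; the dequantization layout below is an assumption for illustration, not code from this repository:

    import torch

    def ref_block_int8_matmul(q_input, weight, x_scale, weight_scale, block, out_dtype):
        """Dequantize with per-block scales, then do a dense float matmul (assumed layout)."""
        bn, bk = block
        # Activations: one scale per (row, K-block); broadcast each scale over its block.
        x = q_input.float() * x_scale.repeat_interleave(bk, dim=1)[:, :q_input.shape[1]]
        # Weights: one scale per (N-block, K-block).
        w = (weight.float()
             * weight_scale.repeat_interleave(bn, dim=0)[:weight.shape[0]]
               .repeat_interleave(bk, dim=1)[:, :weight.shape[1]])
        return (x @ w.t()).to(out_dtype)

    # Loose tolerances, as in the diff: int8 blockwise GEMM is only expected to
    # match a float reference to roughly 1e-2 relative error.
    # assert torch.allclose(output, ref, rtol=1e-2, atol=5e-2)
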
@@ -11,7 +11,7 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import w8a8_matmul_extension
+import lightop
 from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
@@ -471,9 +471,8 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
             best_config=best_config
         )
     else:
-        C_list[i] = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
-            output_dtype=output_dtype
+        C_list[i] = lightop.gemm_w8a8_wgrad_asm(
+            A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size, output_dtype
         )
     return C_list
@@ -671,9 +670,8 @@ def apply_w8a8_block_int8_linear_helper(m: int,
             best_config=best_config
         )
     else:
-        output = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
-            q_input, weight, x_scale, weight_scale, output, False, block_size,
-            output_dtype=out_dtype
+        output = lightop.gemm_w8a8_wgrad_asm(
+            q_input, weight, x_scale, weight_scale, output, False, block_size, out_dtype
         )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
@@ -836,7 +834,7 @@ def main():
     block_size=[blockwise_fp8_block_len, blockwise_fp8_block_len]
-    out_dtype=torch.bfloat16
+    out_dtype=torch.float16
     _n=[]
     _k=[]
...
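
Reviewer note: the wgrad entry point carries an explicit output tensor plus an accumulate flag, so the weight gradient can be accumulated into an existing buffer, and this file's benchmark default also changes from torch.bfloat16 to torch.float16. A sketch of the accumulate contract as the call sites above use it; the in-place add semantics are inferred from "dw.clone() if accumulate else None" in the test diff, not from lightop documentation, and all shapes and scale layouts are assumptions:

    import torch
    import lightop  # assumed importable after this merge

    qdout = torch.randint(-128, 127, (512, 256), dtype=torch.int8, device="cuda")   # dY, int8
    qx = torch.randint(-128, 127, (512, 1024), dtype=torch.int8, device="cuda")     # X, int8
    s_dout = torch.rand(512, 2, device="cuda")   # per-block scales (assumed layout)
    s_x = torch.rand(512, 8, device="cuda")

    # accumulate=False: no prior gradient is passed; the kernel writes a fresh dw.
    dw = lightop.gemm_w8a8_wgrad_asm(qdout, qx, s_dout, s_x, None,
                                     False, [128, 128], torch.float16)
    # accumulate=True: pass the running gradient so the kernel adds into it,
    # e.g. across gradient-accumulation micro-batches.
    dw = lightop.gemm_w8a8_wgrad_asm(qdout, qx, s_dout, s_x, dw,
                                     True, [128, 128], torch.float16)
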