"git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "10cceae9cc7b0bc453565a76f87602b1c824ea19"
Commit bb8cf71b authored by wenjh

Merge branch 'develop_v2.4'

parents 429226fb f5349823
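
This merge makes lightop an optional dependency: the hard `import lightop` becomes a guarded import that sets an `enable_lightop` flag, and every dispatch condition that previously checked only the device compute capability and the block length now also falls back to the Triton w8a8 blockwise-int8 kernels when lightop is not installed. A minimal, runnable sketch of that pattern is shown below; only `lightop`, the `(9, 3)` capability threshold, and the `128` block length come from the diff, while `pick_gemm_backend` and the returned labels are illustrative stand-ins.

# Sketch of the optional-import + fallback pattern applied throughout this merge.
# Only `lightop`, the (9, 3) threshold, and the 128 block length are from the diff;
# `pick_gemm_backend` and the returned labels are illustrative.
try:
    import lightop  # optional fused-kernel backend
    enable_lightop = True
except ImportError:
    enable_lightop = False

def pick_gemm_backend(compute_capability, block_len):
    # Fall back to the Triton blockwise-int8 kernels when the device is too old,
    # the block length is unsupported, or lightop could not be imported.
    if compute_capability < (9, 3) or block_len != 128 or not enable_lightop:
        return "triton_w8a8_block_int8"
    return "lightop"

if __name__ == "__main__":
    print(pick_gemm_backend((9, 0), 128))  # always the Triton fallback
    print(pick_gemm_backend((9, 3), 128))  # lightop only if the import succeeded
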
@@ -2,7 +2,12 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
@@ -197,7 +202,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
             output_dtype=out_dtype
@@ -380,7 +385,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
             output_dtype=dx_dtype
@@ -564,7 +569,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
-    if get_device_compute_capability() < (9, 3) or block_len != 128:
+    if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
         y, _ = w8a8_block_int8_matmul_wgrad(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
             accumulate, [block_len, block_len],
......
@@ -8,7 +8,12 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from ..constants import TE_DType, TE_DType_To_Torch
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -86,7 +91,7 @@ def general_gemm(
         )
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             y, _ = w8a8_block_int8_matmul(
                 qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
@@ -106,7 +111,7 @@ def general_gemm(
         )
         ref_scales_dout = B._rowwise_scale_inv
         ref_scales_w = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             y, _ = w8a8_block_int8_matmul(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
@@ -126,7 +131,7 @@ def general_gemm(
         )
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        if get_device_compute_capability() < (9, 3) or block_len != 128:
+        if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
             out, _ = w8a8_block_int8_matmul_wgrad(
                 qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
......
@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
@@ -610,7 +615,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
         output,config = w8a8_block_int8_matmul(
             q_input, weight, x_scale, weight_scale, block_size,
             output_dtype=out_dtype,
......
@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
-import lightop
+import warnings
+try:
+    import lightop
+    enable_lightop = True
+except ImportError:
+    enable_lightop = False
 from transformer_engine.pytorch.utils import get_device_compute_capability
 logger = logging.getLogger(__name__)
@@ -464,7 +469,7 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
 ):
     for i in range(len(C_list)):
         assert C_list[i] is not None
-        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
             C_list[i], config = w8a8_block_int8_matmul_wgrad(
                 A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
                 output_dtype=output_dtype,
@@ -663,7 +668,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf: transposed, then passed to the triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
         output,config = w8a8_block_int8_matmul_wgrad(
             q_input, weight, x_scale, weight_scale, output, False, block_size,
             output_dtype=out_dtype,
......
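
For the test file touched above, the same flag can also gate cases that only make sense when lightop is present. The guard below is not part of this merge; it is a hypothetical companion sketch using a standard pytest marker, mirroring the `or not enable_lightop` fallback added in the diff.

import pytest

try:
    import lightop
    enable_lightop = True
except ImportError:
    enable_lightop = False

# Hypothetical marker, not part of this merge: skip lightop-only cases when the
# optional package is missing.
requires_lightop = pytest.mark.skipif(not enable_lightop, reason="lightop is not installed")

@requires_lightop
def test_lightop_backend_selected():
    # Trivial placeholder body; real tests would compare lightop and Triton outputs.
    assert enable_lightop
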