Commit bb8cf71b authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents 429226fb f5349823
......@@ -2,7 +2,12 @@ import pytest
import torch
import transformer_engine as te
import transformer_engine_torch as tex
import lightop
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch import get_device_compute_capability
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
......@@ -197,7 +202,7 @@ def cublas_gemm_fp8_blockwise_case_fw(
ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
output_dtype=out_dtype
......@@ -380,7 +385,7 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
output_dtype=dx_dtype
......@@ -564,7 +569,7 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
# print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
# print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
accumulate, [block_len, block_len],
......
......@@ -8,7 +8,12 @@ from typing import Iterable, Optional, Tuple, Union, List
import os
import torch
import transformer_engine_torch as tex
import lightop
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from ..constants import TE_DType, TE_DType_To_Torch
from ..utils import get_sm_count, _empty_tensor
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
......@@ -86,7 +91,7 @@ def general_gemm(
)
ref_scales_x = B._rowwise_scale_inv
ref_scales_w = A._rowwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......@@ -106,7 +111,7 @@ def general_gemm(
)
ref_scales_dout = B._rowwise_scale_inv
ref_scales_w = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
y, _ = w8a8_block_int8_matmul(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......@@ -126,7 +131,7 @@ def general_gemm(
)
ref_scales_dout = B._columnwise_scale_inv
ref_scales_x = A._columnwise_scale_inv
if get_device_compute_capability() < (9, 3) or block_len != 128:
if get_device_compute_capability() < (9, 3) or block_len != 128 or not enable_lightop:
out, _ = w8a8_block_int8_matmul_wgrad(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
output_dtype=out_dtype
......
......@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import lightop
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch.utils import get_device_compute_capability
logger = logging.getLogger(__name__)
......@@ -610,7 +615,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
x_scale = x_scale.permute(1, 0).contiguous()
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
output,config = w8a8_block_int8_matmul(
q_input, weight, x_scale, weight_scale, block_size,
output_dtype=out_dtype,
......
......@@ -11,7 +11,12 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
import functools
import logging
import lightop
import warnings
try:
import lightop
enable_lightop = True
except ImportError:
enable_lightop = False
from transformer_engine.pytorch.utils import get_device_compute_capability
logger = logging.getLogger(__name__)
......@@ -464,7 +469,7 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
):
for i in range(len(C_list)):
assert C_list[i] is not None
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
C_list[i], config = w8a8_block_int8_matmul_wgrad(
A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
output_dtype=output_dtype,
......@@ -663,7 +668,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
C_shape = q_input.shape[:-1] + (N,)
output = q_input.new_empty(C_shape, dtype=out_dtype)
print(f"zhenggf 转置后传递给triton kernel, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
if get_device_compute_capability() < (9, 3) or block_size[1] != 128 or not enable_lightop:
output,config = w8a8_block_int8_matmul_wgrad(
q_input, weight, x_scale, weight_scale, output, False, block_size,
output_dtype=out_dtype,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment