Commit fdf60506 authored by wenjh

Merge branch 'develop_v2.4'

parents 403db136 3b1f30a9
@@ -2,7 +2,8 @@ import pytest
 import torch
 import transformer_engine as te
 import transformer_engine_torch as tex
+import w8a8_matmul_extension
 from transformer_engine.pytorch import get_device_compute_capability
 from transformer_engine.pytorch.constants import TE_DType
 from transformer_engine.pytorch.fp8 import (FP8GlobalStateManager, blockwise_fp8_block_len)
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
@@ -196,10 +197,16 @@ def cublas_gemm_fp8_blockwise_case_fw(
     ref_scales_x = qx._columnwise_scale_inv if x_columnwise else qx._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
         y, _ = w8a8_block_int8_matmul(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
             output_dtype=out_dtype
         )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            qx_data, qw_data, ref_scales_x, ref_scales_w, [block_len, block_len],
+            output_dtype=out_dtype
+        )
     # print("int8 gemm output: ", y)
     # print("int8 gemm output shape: ", y.shape)
@@ -374,10 +381,16 @@ def cublas_gemm_fp8_blockwise_case_bw_xgrad(
     ref_scales_dout = qdout._columnwise_scale_inv if dout_columnwise else qdout._rowwise_scale_inv
     ref_scales_w = qw._columnwise_scale_inv if w_columnwise else qw._rowwise_scale_inv
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
         y, _ = w8a8_block_int8_matmul(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
             output_dtype=dx_dtype
         )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            qdout_data, qw_data, ref_scales_dout, ref_scales_w, [block_len, block_len],
+            output_dtype=dx_dtype
+        )
     # print("int8 gemm dx: ", y)
@@ -553,12 +566,18 @@ def cublas_gemm_fp8_blockwise_case_bw_wgrad(
     # print(f"qdout_data.shape: {qdout_data.shape}, qx_data.shape: {qx_data.shape}")
     # print(f"ref_scales_dout.shape: {ref_scales_dout.shape}, ref_scales_x.shape: {ref_scales_x.shape}")
+    if get_device_compute_capability() < (9, 3) or block_len != 128:
         y, _ = w8a8_block_int8_matmul_wgrad(
             qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
             accumulate, [block_len, block_len],
             output_dtype=dw_dtype
         )
+    else:
+        y = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, dw.clone() if accumulate else None,
+            accumulate, [block_len, block_len],
+            output_dtype=dw_dtype
+        )
     # print("int8 gemm dw: ", y)
......
@@ -8,6 +8,7 @@ from typing import Iterable, Optional, Tuple, Union, List
 import os
 import torch
 import transformer_engine_torch as tex
+import w8a8_matmul_extension
 from ..constants import TE_DType
 from ..utils import get_sm_count, _empty_tensor
 from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
@@ -75,11 +76,16 @@ def general_gemm(
         )
         ref_scales_x = B._rowwise_scale_inv
         ref_scales_w = A._rowwise_scale_inv
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
             y, _ = w8a8_block_int8_matmul(
                 qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
+        else:
+            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+                qx_data, qw_data, ref_scales_x, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return y, None, None, None
     elif layout == "NN":
@@ -91,11 +97,16 @@ def general_gemm(
         )
         ref_scales_dout = B._rowwise_scale_inv
         ref_scales_w = A._columnwise_scale_inv
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
             y, _ = w8a8_block_int8_matmul(
                 qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
+        else:
+            y = w8a8_matmul_extension.w8a8_block_int8_matmul(
+                qdout_data, qw_data, ref_scales_dout, ref_scales_w, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return y, None, None, None
     elif layout == "NT":
@@ -107,11 +118,16 @@ def general_gemm(
         )
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
+        if get_device_compute_capability() < (9, 3) or block_len != 128:
             out, _ = w8a8_block_int8_matmul_wgrad(
                 qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
                 output_dtype=out_dtype
             )
+        else:
+            out = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+                qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [blockwise_fp8_block_len, blockwise_fp8_block_len],
+                output_dtype=out_dtype
+            )
         return out, None, None, None
     else:
......
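Across all three `general_gemm` layouts (and the test hunks above), this merge adds the same dispatch rule: route to the fused `w8a8_matmul_extension` kernel only on devices with compute capability of at least (9, 3) and only for 128-wide quantization blocks; otherwise keep the existing Triton path. A minimal sketch of that rule in isolation; the wrapper name `block_int8_gemm` is illustrative and not part of the diff:

```python
import w8a8_matmul_extension
from transformer_engine.pytorch.utils import get_device_compute_capability
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul

def block_int8_gemm(qx, qw, scales_x, scales_w, block_len, out_dtype):
    # Fall back to the Triton kernel on older GPUs or non-128 block sizes.
    if get_device_compute_capability() < (9, 3) or block_len != 128:
        # The Triton kernel returns (output, autotune_config).
        y, _ = w8a8_block_int8_matmul(
            qx, qw, scales_x, scales_w, [block_len, block_len],
            output_dtype=out_dtype,
        )
    else:
        # The extension kernel returns the output tensor directly.
        y = w8a8_matmul_extension.w8a8_block_int8_matmul(
            qx, qw, scales_x, scales_w, [block_len, block_len],
            output_dtype=out_dtype,
        )
    return y
```

Note the asymmetry the diff handles at every call site: the Triton kernel returns a `(tensor, config)` pair, while the extension returns just the tensor.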
@@ -11,6 +11,8 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
+import w8a8_matmul_extension
+from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)
@@ -574,7 +576,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     output = torch.stack(output).contiguous()
-    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size, out_dtype)
     torch_output = torch_output.view(-1, torch_output.size(-1))
     # print(f"zhenggf, torch_output:{torch_output.shape}")
@@ -605,16 +607,20 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     q_input, x_scale, weight, weight_scale = _int8_gemm_helper(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
     print(f"zhenggf, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
+    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         output, config = w8a8_block_int8_matmul(
             q_input, weight, x_scale, weight_scale, block_size,
             output_dtype=out_dtype,
             best_config=best_config
         )
+    else:
+        output = w8a8_matmul_extension.w8a8_block_int8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size,
+            output_dtype=out_dtype
+        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check failed!!!")
@@ -622,6 +628,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     else:
         print("triton accuracy check passed")
     # unit test end
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
             for it in range(1000):
@@ -800,6 +807,7 @@ def main():
         best_config = []
         apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
@@ -816,8 +824,10 @@ def main():
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
+        else:
+            apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         # Create a DataFrame containing these lists
         df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k, 'linear layer gemm quant op time': cost_times, 'GPU op time': gpu_costtimes,
                            'BLOCK_SIZE_M': _configs_block_m, 'BLOCK_SIZE_N': _configs_block_n, 'BLOCK_SIZE_K': _configs_block_k,
......
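The helper above only benchmarks the Triton path, capturing 1000 launches into a single CUDA graph so that per-launch Python and driver overhead drops out of the measurement. A minimal standalone sketch of that timing technique, assuming `run_gemm` wraps one kernel launch (both names here are illustrative):

```python
import torch

def time_with_cuda_graph(run_gemm, iters=1000):
    # Warm up on a side stream so capture sees fully initialized state.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        run_gemm()
    torch.cuda.current_stream().wait_stream(s)

    # Record `iters` launches into one graph, then replay them as a unit.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        for _ in range(iters):
            run_gemm()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    g.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per launch
```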
@@ -11,7 +11,8 @@ from transformer_engine.pytorch.triton.per_token_group_quant import _int8_gemm_h
 from transformer_engine.pytorch.fp8 import blockwise_fp8_block_len
 import functools
 import logging
+import w8a8_matmul_extension
+from transformer_engine.pytorch.utils import get_device_compute_capability

 logger = logging.getLogger(__name__)
 device_name = torch.cuda.get_device_properties('cuda').name.replace(" ", "_")
@@ -463,11 +464,17 @@ def w8a8_block_int8_matmul_wgrad_batched_native(
 ):
     for i in range(len(C_list)):
         assert C_list[i] is not None
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
             C_list[i], config = w8a8_block_int8_matmul_wgrad(
                 A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
                 output_dtype=output_dtype,
                 best_config=best_config
             )
+        else:
+            C_list[i] = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+                A_list[i], B_list[i], As_list[i], Bs_list[i], C_list[i], accumulate, block_size,
+                output_dtype=output_dtype
+            )
     return C_list

 def w8a8_block_int8_matmul_wgrad_batched(
def w8a8_block_int8_matmul_wgrad_batched(
@@ -613,7 +620,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
-    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
+    torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size, out_dtype)
     # print(f"zhenggf, torch_output:{torch_output.shape}")
     x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
@@ -648,7 +655,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
     print(f"zhenggf, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
-    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size)
+    torch_output = native_w8a8_block_int8_matmul(q_input, weight, x_scale, weight_scale, block_size, out_dtype)
     x_scale = x_scale.permute(1, 0).contiguous()
     weight_scale = weight_scale.permute(1, 0).contiguous()
@@ -657,13 +664,17 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     C_shape = q_input.shape[:-1] + (N,)
     output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf, passing to the triton kernel after transpose: q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         output, config = w8a8_block_int8_matmul_wgrad(
             q_input, weight, x_scale, weight_scale, output, False, block_size,
             output_dtype=out_dtype,
             best_config=best_config
         )
+    else:
+        output = w8a8_matmul_extension.w8a8_block_int8_matmul_wgrad(
+            q_input, weight, x_scale, weight_scale, output, False, block_size,
+            output_dtype=out_dtype
+        )
     if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
         print("triton accuracy check failed!!!")
@@ -671,6 +682,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     else:
         print("triton accuracy check passed")
     # unit test end
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         g = torch.cuda.CUDAGraph()
         with torch.cuda.graph(g):
             for it in range(1000):
@@ -850,7 +862,7 @@ def main():
         best_config = []
         apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+        if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
             output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
             cost_times.append(elapsed_time)
@@ -867,8 +879,10 @@ def main():
             _configs_block_num_warps.append(config.num_warps)
             _configs_block_num_stages.append(config.num_stages)
             # _configs_kpack.append(config['kpack'])
+        else:
+            apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
+    if get_device_compute_capability() < (9, 3) or block_size[1] != 128:
         # Create a DataFrame containing these lists
         df = pd.DataFrame({'m': _m, 'n': _n, 'k': _k, 'linear layer gemm quant op time': cost_times, 'GPU op time': gpu_costtimes,
                            'BLOCK_SIZE_M': _configs_block_m, 'BLOCK_SIZE_N': _configs_block_n, 'BLOCK_SIZE_K': _configs_block_k,
......
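Both benchmark files validate the fast kernels against `native_w8a8_block_int8_matmul` with `torch.allclose(rtol=1e-2, atol=5e-2)`; the new `out_dtype` argument makes the reference produce the same output dtype as the kernels. A rough sketch of what such a blockwise int8 reference computes, assuming per-token, per-K-group activation scales and per-block weight scales (these layouts are assumptions, not taken from the diff):

```python
import torch

# Assumed layouts (illustrative only):
#   A:  [M, K] int8, As: [M, K // block_k]            per-token, per-K-group scales
#   B:  [N, K] int8, Bs: [N // block_n, K // block_k] per-weight-block scales
def native_block_int8_matmul_sketch(A, B, As, Bs, block_size, output_dtype=torch.bfloat16):
    block_n, block_k = block_size
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    M, K = A.shape
    N, _ = B.shape
    C = torch.zeros(M, N, device=A.device, dtype=torch.float32)
    for g in range(K // block_k):
        # int8 partial product over one K group, accumulated in fp32
        a = A[:, g * block_k:(g + 1) * block_k]   # [M, block_k]
        b = B[:, g * block_k:(g + 1) * block_k]   # [N, block_k]
        partial = a @ b.t()                       # [M, N]
        # dequantize with this group's activation and weight scales
        a_s = As[:, g].unsqueeze(1)                              # [M, 1]
        b_s = Bs[:, g].repeat_interleave(block_n).unsqueeze(0)   # [1, N]
        C += partial * a_s * b_s
    return C.to(output_dtype)
```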