Commit b7afba08 authored by yuguo

[DCU] Support blockwise FP8 simulation with INT8 for MoE

parent 735227cd
......@@ -10,8 +10,8 @@ import torch
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import get_sm_count
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt import w8a8_block_int8_matmul, w8a8_block_int8_matmul_batched
from transformer_engine.pytorch.triton.blockwise_int8_gemm_nt_wgrad import w8a8_block_int8_matmul_wgrad, w8a8_block_int8_matmul_wgrad_batched
from ..tensor.quantized_tensor import Quantizer
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...debug.pytorch.debug_quantization import DebugQuantizer
......@@ -205,6 +205,63 @@ def general_grouped_gemm(
transa = layout[0] == "T"
transb = layout[1] == "T"
if int8_simulation_fp8 and (isinstance(A[0], Float8BlockwiseQTensorBase) or isinstance(B[0], Float8BlockwiseQTensorBase)):
assert len(set(m_splits)) == 1, "Int8-simulation grouped GEMM currently requires equal m_splits: pad tokens to a uniform size, as with batched GEMM."
assert not gelu, "GELU is not supported with int8-simulation grouped GEMM."
assert bias is None, "Bias is not supported with int8-simulation grouped GEMM."
assert not accumulate, "Accumulation is not supported with int8-simulation grouped GEMM."
if layout == "TN":
qx_data = [b._rowwise_data.view(dtype=torch.int8) for b in B]
qw_data = [a._rowwise_data.view(dtype=torch.int8) for a in A]
ref_scales_x = [b._rowwise_scale_inv for b in B]
ref_scales_w = [a._rowwise_scale_inv for a in A]
y = w8a8_block_int8_matmul_batched(
qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
output_dtype=out_dtype,
)
return y, None, None
elif layout == "NN":
qdout_data = [b._rowwise_data.view(dtype=torch.int8) for b in B]
qw_data = [a._columnwise_data.view(dtype=torch.int8) for a in A]
ref_scales_dout = [b._rowwise_scale_inv for b in B]
ref_scales_w = [a._columnwise_scale_inv for a in A]
y = w8a8_block_int8_matmul_batched(
qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
output_dtype=out_dtype,
)
return y, None, None
elif layout == "NT":
qdout_data = [b._columnwise_data.view(dtype=torch.int8) for b in B]
qx_data = [a._columnwise_data.view(dtype=torch.int8) for a in A]
ref_scales_dout = [b._columnwise_scale_inv for b in B]
ref_scales_x = [a._columnwise_scale_inv for a in A]
y = w8a8_block_int8_matmul_wgrad_batched(
qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
output_dtype=out_dtype,
)
return y, None, None
else:
raise ValueError(f"Unsupported layout {layout} for int8-simulation FP8 GEMM")
empty_tensor = _empty_tensor()
empty_tensors = [empty_tensor] * num_gemms
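For intuition, the TN path above is equivalent (up to rounding) to dequantizing each operand and running a plain per-expert matmul. A minimal sketch, assuming 1x128 per-row scales for the activations and 128x128 block scales for the weights; the dequant_act/dequant_weight helpers are illustrative only and not part of this commit:

import torch

def dequant_act(q, s, block_k=128):
    # q: [M, K] int8; s: [M, K // block_k] per-row, per-K-block scale_inv.
    return q.float() * s.repeat_interleave(block_k, dim=1)[:, : q.shape[1]]

def dequant_weight(q, s, block_n=128, block_k=128):
    # q: [N, K] int8; s: [N // block_n, K // block_k] block scale_inv.
    s = s.repeat_interleave(block_n, dim=0)[: q.shape[0], :]
    return q.float() * s.repeat_interleave(block_k, dim=1)[:, : q.shape[1]]

# Per expert i, the TN path computes approximately
#   y[i] = dequant_act(qx_data[i], x_scales[i]) @ dequant_weight(qw_data[i], w_scales[i]).T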
......@@ -276,6 +333,9 @@ def batchgemm(
empty_tensor = torch.Tensor()
empty_tensors = [torch.Tensor()] * num_gemms
if int8_simulation_fp8:
assert False, "Batched GEMM is not supported with int8 simulation: unset GROUPED_GEMM_BatchLinear and use the MoE grouped GEMM with padded tokens."
if gelu and not grad:
gelu_input = [
torch.empty_like(o, dtype=dtype, memory_format=torch.contiguous_format) for o in out
......
......@@ -338,7 +338,210 @@ def w8a8_block_int8_matmul(
return C, config
@triton.jit
def _w8a8_block_int8_matmul_batched(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_a_batch,
stride_am,
stride_ak,
stride_b_batch,
stride_bk,
stride_bn,
stride_c_batch,
stride_cm,
stride_cn,
stride_as_batch,
stride_As_m,
stride_As_k,
stride_bs_batch,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid_mn = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_mn // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_mn % group_size_m)
pid_n = (pid_mn % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# BLOCK_SIZE_N equals group_n (both are block_n in the wrapper), so a single
# scalar block index along N covers the whole output tile's B scales.
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + pid_batch * stride_a_batch + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + pid_batch * stride_b_batch + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + pid_batch * stride_as_batch + offs_am * stride_As_m
Bs_ptrs = Bs + pid_batch * stride_bs_batch + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
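The grid launches one program per (M-tile, N-tile) pair on axis 0 and one per batch element on axis 1; axis 0 is swizzled in groups of GROUP_SIZE_M M-tiles for better L2 reuse of the B tiles. A plain-Python mirror of that index arithmetic, handy for sanity-checking tile coverage on the host (illustrative only, not part of the commit):

def tile_of(pid_mn, M, N, BLOCK_M=64, BLOCK_N=128, GROUP_M=8):
    # Mirrors the GROUP_SIZE_M swizzle in _w8a8_block_int8_matmul_batched:
    # consecutive program ids sweep GROUP_M tiles down M before stepping
    # along N.
    cdiv = lambda a, b: -(-a // b)
    num_pid_m, num_pid_n = cdiv(M, BLOCK_M), cdiv(N, BLOCK_N)
    num_pid_in_group = GROUP_M * num_pid_n
    group_id = pid_mn // num_pid_in_group
    first_pid_m = group_id * GROUP_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
    pid_m = first_pid_m + (pid_mn % group_size_m)
    pid_n = (pid_mn % num_pid_in_group) // group_size_m
    return pid_m, pid_n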
def w8a8_block_int8_matmul_batched(
A_list, B_list, As_list, Bs_list,
block_size, output_dtype=torch.float16, best_config=None
):
A = torch.stack(A_list).contiguous() # [B, M, K]
B = torch.stack(B_list).contiguous() # [B, N, K]
As = torch.stack(As_list).contiguous()  # [B, K // block_k, M] (K-major)
Bs = torch.stack(Bs_list).contiguous()  # [B, N // block_n, K // block_k]
assert A.shape[-1] == B.shape[-1]
M = A.numel() // A.shape[-1] // A.shape[0]
batch, N, K = B.shape
block_n, block_k = block_size
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 1,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
batch,
)
_w8a8_block_int8_matmul_batched[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(0),
A.stride(-2),
A.stride(-1),
B.stride(0),
B.stride(-1),
B.stride(-2),
C.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(0),
As.stride(-1),
As.stride(-2),
Bs.stride(0),
Bs.stride(-1),
Bs.stride(-2),
**config,
)
return C
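A minimal usage sketch (not part of the commit). Shapes are inferred from the stride arguments above: the activation scales are passed K-major, [K // 128, M], matching the permute(1, 0) in the helper below, while the weight scales stay [N // 128, K // 128]:

import torch

bsz, M, N, K = 4, 256, 512, 1024
A_list = [torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda") for _ in range(bsz)]
B_list = [torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda") for _ in range(bsz)]
As_list = [torch.rand(K // 128, M, device="cuda") for _ in range(bsz)]         # K-major activation scales
Bs_list = [torch.rand(N // 128, K // 128, device="cuda") for _ in range(bsz)]  # 128x128 weight block scales
C = w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, [128, 128],
                                   output_dtype=torch.bfloat16)  # C: [bsz, M, N]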
def apply_w8a8_block_int8_linear_batched_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int] = [128, 128],
bias: Optional[torch.Tensor] = None,
best_config: Optional[dict] = None):
batch = 4
q_input, x_scale, weight, weight_scale = _int8_gemm_helper(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
q_input_b = [q_input.clone().contiguous() for _ in range(batch)]
x_scale_b = [x_scale.clone().contiguous() for _ in range(batch)]
weight_b = [weight.clone().contiguous() for _ in range(batch)]
weight_scale_b = [weight_scale.clone().contiguous() for _ in range(batch)]
torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
# The Triton wrapper expects activation scales K-major: [K // block_k, M].
x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
output = w8a8_block_int8_matmul_batched(
q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("Triton accuracy check FAILED!")
else:
print("Triton accuracy check passed.")
# unit test end
def apply_w8a8_block_int8_linear_helper(m: int,
n: int,
......@@ -489,6 +692,24 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.b
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def native_w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, block_size, output_dtype=torch.bfloat16):
"""
Batched version of native block-wise quantized matmul.
Args:
A_list (List[Tensor]): [B, M, K]
B_list (List[Tensor]): [B, N, K]
As_list (List[Tensor]): [B, M, K // block_k]
Bs_list (List[Tensor]): [B, N // block_n, K // block_k]
Returns:
Tensor: [B, M, N]
"""
results = []
for A, B, As, Bs in zip(A_list, B_list, As_list, Bs_list):
C = native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype)
results.append(C)
return torch.stack(results)
def main():
m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
......@@ -529,7 +750,7 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
cost_times.append(elapsed_time)
......
......@@ -154,6 +154,110 @@ def _w8a8_block_int8_matmul(
tl.store(c_ptrs, c, mask=c_mask)
@triton.jit
def _w8a8_block_int8_matmul_batched(
# Pointers to inputs and output
A,
B,
C,
As,
Bs,
# Shape for matmul
M,
N,
K,
# Block size for block-wise quantization
group_n,
group_k,
# Stride for inputs and output
stride_a_batch,
stride_am,
stride_ak,
stride_b_batch,
stride_bk,
stride_bn,
stride_c_batch,
stride_cm,
stride_cn,
stride_as_batch,
stride_As_m,
stride_As_k,
stride_bs_batch,
stride_Bs_k,
stride_Bs_n,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization,
and store the result in output tensor `C`.
"""
pid_mn = tl.program_id(axis=0)
pid_batch = tl.program_id(axis=1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid_mn // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid_mn % group_size_m)
pid_n = (pid_mn % num_pid_in_group) // group_size_m
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# Unlike the dgrad-side kernel of the same name, B here carries per-column
# (1 x group_k) scales, so each output column indexes its own scale along N.
offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + pid_batch * stride_a_batch + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + pid_batch * stride_b_batch + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
As_ptrs = As + pid_batch * stride_as_batch + offs_am * stride_As_m
Bs_ptrs = Bs + pid_batch * stride_bs_batch + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16:
c = accumulator.to(tl.float16)
else:
c = accumulator.to(tl.float32)
offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
tl.store(c_ptrs, c, mask=c_mask)
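Apart from this, the kernel is identical to the dgrad-side _w8a8_block_int8_matmul_batched earlier in the commit; the one functional difference is the B-scale indexing:

# dgrad kernel: one scalar scale index per 128-wide N tile
#   (Bs holds [N // group_n, K // group_k] block scales)
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
# wgrad kernel (this file): one scale per output column
#   (Bs holds [N, K // group_k] per-row scales, see the native reference below)
offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N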
@functools.lru_cache
def get_w8a8_block_int8_configs(
N: int, K: int, block_n: int, block_k: int
......@@ -338,6 +442,107 @@ def w8a8_block_int8_matmul_wgrad(
return C, config
def w8a8_block_int8_matmul_wgrad_batched(
A_list, B_list, As_list, Bs_list,
block_size, output_dtype=torch.float16, best_config=None
):
A = torch.stack(A_list).contiguous() # [B, M, K]
B = torch.stack(B_list).contiguous() # [B, N, K]
As = torch.stack(As_list).contiguous()  # [B, K // block_k, M] (K-major)
Bs = torch.stack(Bs_list).contiguous()  # [B, K // block_k, N] (K-major)
assert A.shape[-1] == B.shape[-1]
M = A.numel() // A.shape[-1] // A.shape[0]
batch, N, K = B.shape
block_n, block_k = block_size
C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": block_n,
"BLOCK_SIZE_K": block_k,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 1,
}
def grid(META):
return (
triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
batch,
)
_w8a8_block_int8_matmul_batched[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(0),
A.stride(-2),
A.stride(-1),
B.stride(0),
B.stride(-1),
B.stride(-2),
C.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(0),
As.stride(-1),
As.stride(-2),
Bs.stride(0),
Bs.stride(-2),
Bs.stride(-1),
**config,
)
return C
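A minimal usage sketch for the wgrad variant (not part of the commit). Per the two permute(1, 0) calls in the helper below, both scale tensors are passed K-major: As as [K // 128, M] and Bs as [K // 128, N] (per-column scales):

import torch

bsz, M, N, K = 4, 256, 512, 1024
A_list = [torch.randint(-128, 128, (M, K), dtype=torch.int8, device="cuda") for _ in range(bsz)]
B_list = [torch.randint(-128, 128, (N, K), dtype=torch.int8, device="cuda") for _ in range(bsz)]
As_list = [torch.rand(K // 128, M, device="cuda") for _ in range(bsz)]  # K-major per-row scales
Bs_list = [torch.rand(K // 128, N, device="cuda") for _ in range(bsz)]  # K-major per-column scales
C = w8a8_block_int8_matmul_wgrad_batched(A_list, B_list, As_list, Bs_list, [128, 128],
                                         output_dtype=torch.bfloat16)  # C: [bsz, M, N]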
def apply_w8a8_block_int8_linear_batched_helper(m: int,
n: int,
k: int,
out_dtype: Type[torch.dtype] = torch.float16,
device: str = "cuda",
block_size: List[int] = [128, 128],
bias: Optional[torch.Tensor] = None,
best_config: Optional[dict] = None):
batch = 4
q_input, x_scale, weight, weight_scale = _int8_gemm_helper_b(m=m, n=n, k=k, out_dtype=out_dtype, device=device, block_size=block_size)
q_input_b = [q_input.clone().contiguous() for _ in range(batch)]
x_scale_b = [x_scale.clone().contiguous() for _ in range(batch)]
weight_b = [weight.clone().contiguous() for _ in range(batch)]
weight_scale_b = [weight_scale.clone().contiguous() for _ in range(batch)]
torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
# The wgrad wrapper expects both scale tensors K-major before the call:
# x_scale as [K // block_k, M] and weight_scale as [K // block_k, N].
x_scale_b = [xs.permute(1, 0).contiguous() for xs in x_scale_b]
weight_scale_b = [ws.permute(1, 0).contiguous() for ws in weight_scale_b]
output = w8a8_block_int8_matmul_wgrad_batched(
q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
output_dtype=out_dtype,
best_config=best_config
)
if not torch.allclose(output, torch_output, rtol=1e-2, atol=5e-2):
print("Triton accuracy check FAILED!")
else:
print("Triton accuracy check passed.")
# unit test end
def apply_w8a8_block_int8_linear_helper(m: int,
......@@ -494,6 +699,23 @@ def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.b
C = C.reshape(origin_C_shape).to(output_dtype)
return C
def native_w8a8_block_int8_matmul_batched(A_list, B_list, As_list, Bs_list, block_size, output_dtype=torch.bfloat16):
"""
Batched version of native block-wise quantized matmul.
Args:
A_list (List[Tensor]): [B, M, K]
B_list (List[Tensor]): [B, N, K]
As_list (List[Tensor]): [B, M, K // block_k]
Bs_list (List[Tensor]): [B, N, K // block_k]
Returns:
Tensor: [B, M, N]
"""
results = []
for A, B, As, Bs in zip(A_list, B_list, As_list, Bs_list):
C = native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype)
results.append(C)
return torch.stack(results)
def main():
m1 = [item if item < 17 else 1 << (item - 27) for item in range(1, 17)]
m2 = [item << 2 if item < 17 else (item - 8) << 3 for item in range(5, 29)]
......@@ -534,6 +756,7 @@ def main():
best_config = []
apply_w8a8_block_int8_linear_batched_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
output, elapsed_time, gpu_costtime, config = apply_w8a8_block_int8_linear_helper(m=m, n=n_list[i], k=k_list[i], block_size=block_size, out_dtype=out_dtype, best_config=best_config)
......