Commit 3653fbfb authored by yuguo

[DCU] fix int8 simulation fp8 fused wgrad accumulation
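Before this change, the wgrad paths computed a fresh result tensor and then accumulated into the existing gradient buffer on the Python side (`y = y + out`). This commit fuses that addition into the Triton kernel epilogue behind a compile-time `accumulate` flag, so the buffer is read and updated in a single pass. A torch-only sketch of the invariant the fused path must preserve; the helper and tensor names below are illustrative, not the repo's API:

```python
import torch

# Reference semantics: with accumulate=True, the fused kernel must return
# exactly what the old two-step "y = gemm(...); y = y + out" returned.
def wgrad_ref(dout, x, out, accumulate):
    y = dout @ x.t()  # stand-in for the block-int8 GEMM
    return y + out if accumulate else y

dout = torch.randn(64, 128)
x = torch.randn(96, 128)
out = torch.randn(64, 96)  # pre-existing gradient buffer

fused = wgrad_ref(dout, x, out, accumulate=True)
two_step = (dout @ x.t()) + out
torch.testing.assert_close(fused, two_step)
```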

parent ecdd8251
@@ -79,9 +79,6 @@ def general_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            y = y + out
         return y, None, None, None
     elif layout == "NN":
@@ -98,9 +95,6 @@ def general_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            y = y + out
         return y, None, None, None
     elif layout == "NT":
@@ -113,14 +107,11 @@ def general_gemm(
         ref_scales_dout = B._columnwise_scale_inv
         ref_scales_x = A._columnwise_scale_inv
-        y, _ = w8a8_block_int8_matmul_wgrad(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+        out, _ = w8a8_block_int8_matmul_wgrad(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            y = y + out
-        return y, None, None, None
+        return out, None, None, None
     else:
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
@@ -226,10 +217,6 @@ def general_grouped_gemm(
             qx_data, qw_data, ref_scales_x, ref_scales_w, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
         return y, None, None
     elif layout == "NN":
@@ -246,10 +233,6 @@ def general_grouped_gemm(
             qdout_data, qw_data, ref_scales_dout, ref_scales_w, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
         return y, None, None
     elif layout == "NT":
@@ -262,15 +245,11 @@ def general_grouped_gemm(
         ref_scales_dout = [b._columnwise_scale_inv for b in B]
         ref_scales_x = [a._columnwise_scale_inv for a in A]
-        y, _ = w8a8_block_int8_matmul_wgrad_batched(
-            qdout_data, qx_data, ref_scales_dout, ref_scales_x, [128, 128],
+        out, _ = w8a8_block_int8_matmul_wgrad_batched(
+            qdout_data, qx_data, ref_scales_dout, ref_scales_x, out, accumulate, [128, 128],
             output_dtype=out_dtype
         )
-        if accumulate:
-            assert out is not None
-            out = torch.stack(out).contiguous()
-            y = y + out
-        return y, None, None
+        return out, None, None
     else:
         raise ValueError(f"Unsupported layout {layout} in int8 simulation fp8")
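In the grouped path, the removed Python-side code stacked the per-group gradient buffers and then added them (`out = torch.stack(out).contiguous(); y = y + out`); the batched kernel now performs the same addition in place. A hedged torch-only sketch of the batched semantics the fused kernel must match; names and shapes are illustrative:

```python
import torch

def batched_wgrad_ref(douts, xs, c_list, accumulate):
    """Reference: per-group douts[i] @ xs[i].T, optionally added into c_list[i]."""
    y = torch.stack([d @ x.t() for d, x in zip(douts, xs)])
    if accumulate:
        y = y + torch.stack(c_list)  # the removed Python-side accumulation
    return y

douts = [torch.randn(8, 16) for _ in range(4)]   # per-group grad-output tiles
xs = [torch.randn(12, 16) for _ in range(4)]     # per-group activations
grads = [torch.randn(8, 12) for _ in range(4)]   # pre-existing gradient buffers

fused = batched_wgrad_ref(douts, xs, grads, accumulate=True)
two_step = torch.stack([d @ x.t() + c for d, x, c in zip(douts, xs, grads)])
torch.testing.assert_close(fused, two_step)
```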
@@ -82,6 +82,7 @@ def _w8a8_block_int8_matmul(
     stride_As_k,
     stride_Bs_k,
     stride_Bs_n,
+    accumulate: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -151,6 +152,8 @@ def _w8a8_block_int8_matmul(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    if accumulate:
+        c += tl.load(c_ptrs, mask=c_mask)
     tl.store(c_ptrs, c, mask=c_mask)
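Because `accumulate` is declared `tl.constexpr`, Triton compiles a separate specialization per value, so the extra `tl.load` exists only in the accumulating variant and the non-accumulating path is unchanged. A minimal, self-contained demo of the same load-add-store epilogue pattern; the kernel name and shapes are illustrative, not the repo's, and running it requires a GPU:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _accumulate_epilogue_demo(C, X, n, accumulate: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X + offs, mask=mask)
    if accumulate:
        # same pattern as the epilogue above: read back the destination
        # and add before storing, fusing "y = y + out" into the kernel
        x += tl.load(C + offs, mask=mask)
    tl.store(C + offs, x, mask=mask)

x = torch.randn(1000, device="cuda")
c = torch.randn(1000, device="cuda")
expected = c + x
_accumulate_epilogue_demo[(triton.cdiv(1000, 256),)](c, x, 1000, accumulate=True, BLOCK=256)
torch.testing.assert_close(c, expected)
```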
@@ -185,6 +188,7 @@ def _w8a8_block_int8_matmul_batched(
     stride_bs_batch,
     stride_Bs_k,
     stride_Bs_n,
+    accumulate: tl.constexpr,
     # Meta-parameters
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -256,6 +260,8 @@ def _w8a8_block_int8_matmul_batched(
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
     c_ptrs = C + pid_batch * stride_c_batch + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    if accumulate:
+        c += tl.load(c_ptrs, mask=c_mask)
     tl.store(c_ptrs, c, mask=c_mask)

 @functools.lru_cache
@@ -304,6 +310,8 @@ def w8a8_block_int8_matmul_wgrad(
     B: torch.Tensor,
     As: torch.Tensor,
     Bs: torch.Tensor,
+    C: torch.Tensor,
+    accumulate: bool,
     block_size: List[int],
     output_dtype: torch.dtype = torch.float16,
     best_config:Optional[dict]=None
@@ -338,6 +346,10 @@ def w8a8_block_int8_matmul_wgrad(
     # assert triton.cdiv(N, block_n) == Bs.shape[0]
     # assert triton.cdiv(K, block_k) == Bs.shape[1]
-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
+    if accumulate:
+        assert C is not None
+    if C is None:
+        C_shape = A.shape[:-1] + (N,)
+        C = A.new_empty(C_shape, dtype=output_dtype)
@@ -435,6 +447,7 @@ def w8a8_block_int8_matmul_wgrad(
         As.stride(0),
         Bs.stride(-2),
         Bs.stride(-1),
+        accumulate,
         # Bs.stride(1),
         # Bs.stride(0),
         # **config,
@@ -445,7 +458,7 @@
 def w8a8_block_int8_matmul_wgrad_batched(
-    A_list, B_list, As_list, Bs_list,
+    A_list, B_list, As_list, Bs_list, C_list, accumulate,
     block_size, output_dtype=torch.float16, best_config=None
 ):
     A = torch.stack(A_list).contiguous()  # [B, M, K]
@@ -462,8 +475,17 @@ def w8a8_block_int8_matmul_wgrad_batched(
     batch, N, K = B.shape
     block_n, block_k = block_size
-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
+    if accumulate:
+        if C_list is None:
+            assert False
+        else:
+            C = torch.stack(C_list).contiguous()
+    else:
+        if C_list is None:
+            C_shape = A.shape[:-1] + (N,)
+            C = A.new_empty(C_shape, dtype=output_dtype)
+        else:
+            C = torch.stack(C_list).contiguous()
     config = {
         "BLOCK_SIZE_M": 64,
@@ -506,6 +528,7 @@ def w8a8_block_int8_matmul_wgrad_batched(
         Bs.stride(0),
         Bs.stride(-2),
         Bs.stride(-1),
+        accumulate,
         **config,
     )
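The wrapper's C handling above reduces to a small decision table: accumulating requires caller-provided buffers, otherwise a fresh output is allocated (or the provided list is stacked and reused). A condensed standalone sketch of that logic with illustrative names; it raises a ValueError where the wrapper uses a bare assert:

```python
import torch

def resolve_output(A, N, C_list, accumulate, output_dtype=torch.float16):
    """Condensed version of the C_list/accumulate handling above (illustrative)."""
    if accumulate and C_list is None:
        raise ValueError("accumulate=True requires existing output buffers")
    if C_list is not None:
        return torch.stack(C_list).contiguous()          # reuse caller buffers, [B, M, N]
    return A.new_empty(A.shape[:-1] + (N,), dtype=output_dtype)  # fresh [B, M, N]

A = torch.randn(4, 8, 16)                        # stacked inputs [B, M, K]
grads = [torch.randn(8, 12) for _ in range(4)]   # per-group gradient buffers
assert resolve_output(A, 12, grads, accumulate=True).shape == (4, 8, 12)
assert resolve_output(A, 12, None, accumulate=False).shape == (4, 8, 12)
```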
@@ -527,6 +550,10 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     x_scale_b = [x_scale.clone().contiguous() for i in range(batch)]
     weight_b = [weight.clone().contiguous() for i in range(batch)]
     weight_scale_b = [weight_scale.clone().contiguous() for i in range(batch)]
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = [q_input.new_empty(C_shape, dtype=out_dtype) for i in range(batch)]
     # print(f"zhenggf, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
     torch_output = native_w8a8_block_int8_matmul_batched(q_input_b, weight_b, x_scale_b, weight_scale_b, block_size)
@@ -537,7 +564,7 @@ def apply_w8a8_block_int8_linear_batched_helper(m: int,
     # print(f"zhenggf, passed to triton kernel after transpose, q_input_b:{q_input_b.shape}, x_scale_b:{x_scale_b.shape}, weight_b:{weight_b.shape}, weight_scale_b:{weight_scale_b.shape}")
     output = w8a8_block_int8_matmul_wgrad_batched(
-        q_input_b, weight_b, x_scale_b, weight_scale_b, block_size,
+        q_input_b, weight_b, x_scale_b, weight_scale_b, output, False, block_size,
         output_dtype=out_dtype,
         best_config=best_config
     )
@@ -568,9 +595,12 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     x_scale = x_scale.permute(1, 0).contiguous()
     weight_scale = weight_scale.permute(1, 0).contiguous()
+    N, K = weight.shape
+    C_shape = q_input.shape[:-1] + (N,)
+    output = q_input.new_empty(C_shape, dtype=out_dtype)
     print(f"zhenggf, passed to triton kernel after transpose, q_input:{q_input.shape}, x_scale:{x_scale.shape}, weight:{weight.shape}, weight_scale:{weight_scale.shape}")
     output,config = w8a8_block_int8_matmul_wgrad(
-        q_input, weight, x_scale, weight_scale, block_size,
+        q_input, weight, x_scale, weight_scale, output, False, block_size,
         output_dtype=out_dtype,
         best_config=best_config
     )
@@ -587,7 +617,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     with torch.cuda.graph(g):
         for it in range(1000):
             output,_ = w8a8_block_int8_matmul_wgrad(
-                q_input, weight, x_scale, weight_scale, block_size,
+                q_input, weight, x_scale, weight_scale, output, False, block_size,
                 output_dtype=out_dtype,
                 best_config=best_config
             )
@@ -600,7 +630,7 @@ def apply_w8a8_block_int8_linear_helper(m: int,
     elapsed_time = round((end_time_ - start_time_) *1000 ,7)  # compute elapsed time
     print("_time:{} us\n".format(elapsed_time))
     quantiles = [0.5, 0.2, 0.8]
-    gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
+    gpu_costtime = triton.testing.do_bench(lambda:w8a8_block_int8_matmul_wgrad(q_input, weight, x_scale, weight_scale, output, False, block_size,output_dtype=out_dtype,best_config=best_config),quantiles=None, return_mode="mean")*1000
     if bias is not None:
         output = output + bias