Commit 90f05cd6 authored by zhuwenwen
Browse files

Merge branch 'v0.8.5.post1-dev-w8a8' into 'v0.8.5.post1-dev'

增加w8a8 线性gemm triton优化

See merge request dcutoolkit/deeplearing/vllm!128
parents 2664c459 f52afe3f
...@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul( ...@@ -300,8 +300,8 @@ def _w8a8_block_int8_matmul(
GROUP_SIZE_M: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
): ):
"""Triton-accelerated function used to perform linear operations (dot """Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and product) on input tensors `A` and `B` with block-wise quantization,
store the result in output tensor `C`. and store the result in output tensor `C`.
""" """
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
...@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul( ...@@ -316,16 +316,29 @@ def _w8a8_block_int8_matmul(
offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
# offs_bsn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_bsn = pid_n * BLOCK_SIZE_N // group_n
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
# a_ptrs = A + (offs_am[:, None] * stride_am)
# b_ptrs = B + (offs_bn[None, :] * stride_bn)
As_ptrs = As + offs_am * stride_As_m As_ptrs = As + offs_am * stride_As_m
offs_bsn = offs_bn // group_n # offs_bsn = offs_bn // group_n
Bs_ptrs = Bs + offs_bsn * stride_Bs_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
a = tl.load(a_ptrs, a = tl.load(a_ptrs,
mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
...@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul( ...@@ -333,16 +346,13 @@ def _w8a8_block_int8_matmul(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0) other=0.0)
k_start = k * BLOCK_SIZE_K
offs_ks = k_start // group_k
a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :]
None] * b_s[None, :]
a_ptrs += BLOCK_SIZE_K * stride_ak a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk b_ptrs += BLOCK_SIZE_K * stride_bk
if C.dtype.element_ty == tl.bfloat16: if C.dtype.element_ty == tl.bfloat16:
c = accumulator.to(tl.bfloat16) c = accumulator.to(tl.bfloat16)
elif C.dtype.element_ty == tl.float16: elif C.dtype.element_ty == tl.float16:
...@@ -436,27 +446,8 @@ def w8a8_block_int8_matmul( ...@@ -436,27 +446,8 @@ def w8a8_block_int8_matmul(
C_shape = A.shape[:-1] + (N, ) C_shape = A.shape[:-1] + (N, )
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
# configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
# if configs:
# # If an optimal configuration map has been found, look up the
# # optimal config
# config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
# else:
# # Default config
# # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
# config = {
# "BLOCK_SIZE_M": 64,
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
#print("W8A8_TRITONJSON.triton_json_list[0]:",W8A8_TRITONJSON.triton_json_list[0])
if len(W8A8_TRITONJSON.triton_json_list)==0: if len(W8A8_TRITONJSON.triton_json_list)==0:
config=None config=None
#print("len(W8A8_TRITONJSON.triton_json_list)=0:",len(W8A8_TRITONJSON.triton_json_list)) triton_json
elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_list[0]: elif f"1_{N}_{K}_block[{block_n},{block_k}]" in W8A8_TRITONJSON.triton_json_list[0]:
if M<=16: if M<=16:
...@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul( ...@@ -480,12 +471,13 @@ def w8a8_block_int8_matmul(
m_=4096 m_=4096
else: else:
m_=8192 m_=8192
#print("==================m:{},n:{},k:{}".format(M,N,K))
config=W8A8_TRITONJSON.triton_json_list[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"] config=W8A8_TRITONJSON.triton_json_list[0][f"{m_}_{N}_{K}_block[{block_n},{block_k}]"]
else: else:
config=None config=None
if config==None:
# print("m:{},n:{},k:{}".format(M,N,K)) # print("m:{},n:{},k:{}".format(M,N,K))
# print("config not found!") # print("config not found!")
......
...@@ -1825,6 +1825,7 @@ class W8a8GetCacheJSON: ...@@ -1825,6 +1825,7 @@ class W8a8GetCacheJSON:
'kpack': int(sub_value["kpack"]), 'kpack': int(sub_value["kpack"]),
'num_stages':int(sub_value['num_stages']), 'num_stages':int(sub_value['num_stages']),
'num_warps':int(sub_value['num_warps']), 'num_warps':int(sub_value['num_warps']),
'enable_mmacfuse':int(sub_value['enable_mmacfuse']),
} }
configs_dict[configs_key]=configs_value configs_dict[configs_key]=configs_value
return configs_dict return configs_dict
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment