"vllm/utils/__init__.py" did not exist on "f52afe3f682c78f5f99337e9f81ef9c4349b69a6"
Commit 8fc15e04 authored by gaoqiong's avatar gaoqiong
Browse files

deepseek_v3/r1 int8 量化首字调优

parent f5f9f42f
......@@ -42,7 +42,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 2,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "kpack": 1,"num_stages": 0,"num_warps": 8}
]
stage2_best_config=[
......@@ -62,7 +65,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #14
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #15
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #16
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4} ,#256
{"BLOCK_SIZE_M": 64,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4},# 8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"kpack": 1,"num_stages": 0,"num_warps": 4}
]
else:
stage1_best_config=[
......@@ -83,7 +90,10 @@ else:
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 2,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 32,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 8,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 4}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},#8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 128,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 8},
]
stage2_best_config=[
......@@ -103,7 +113,11 @@ else:
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2},#13
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #14
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #15
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #16
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 128,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"num_stages": 0,"num_warps": 2}, #32
{"BLOCK_SIZE_M": 16,"BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}, #256
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #1024
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0, "num_warps": 4}, #8192
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}
]
@triton.jit
......@@ -1662,8 +1676,9 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# so the cache size and config are already set correctly and
# do not need to be adjusted.
intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk]
intermediate_cache2 = intermediate_cache2[:tokens_in_chunk * topk_ids.shape[1]]
intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
if not use_int8_w8a8:
config = get_config_func(tokens_in_chunk)
curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
......@@ -1677,24 +1692,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config =stage1_best_config[15]
elif m<=64:
config =stage1_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
elif m<=256:
config =stage1_best_config[17]
elif m<=1024:
config =stage1_best_config[18]
elif m<=8192:
config =stage1_best_config[19]
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
config =stage1_best_config[20]
if moe_ep_size == 1:
if use_int4_w4a16:
......@@ -1740,24 +1745,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config =stage2_best_config[15]
elif m<=64:
config =stage2_best_config[16]
elif m<256:
config ={
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_stages": 0,
"num_warps": 4
}
elif m<=256:
config =stage2_best_config[17]
elif m<=1024:
config =stage2_best_config[18]
elif m<=8192:
config =stage2_best_config[19]
else:
config ={
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_stages": 0,
"num_warps": 4
}
config =stage2_best_config[20]
invoke_fused_moe_kernel(intermediate_cache2,
w2,
......
......@@ -68,7 +68,6 @@ def per_token_quant_int8(x):
return x_q, scales
@triton.jit
def _per_token_group_quant_int8(
# Pointers to inputs and output
......@@ -76,9 +75,12 @@ def _per_token_group_quant_int8(
y_q_ptr,
y_s_ptr,
# Stride of input
y_stride,
# Columns of input
N,
group_size,
# M,
# K,
# # Collums of input
# N,
SIZE,
# Avoid division by zero
eps,
# Information for int8
......@@ -86,6 +88,7 @@ def _per_token_group_quant_int8(
int8_max,
# Meta-parameters
BLOCK: tl.constexpr,
s_num : tl.constexpr,
):
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
......@@ -93,21 +96,26 @@ def _per_token_group_quant_int8(
"""
# Map the program id to the row of X and Y it should compute.
g_id = tl.program_id(0)
y_ptr += g_id * y_stride
y_q_ptr += g_id * y_stride
y_s_ptr += g_id
y_ptr += g_id * BLOCK
y_q_ptr += g_id * BLOCK
y_s_ptr += g_id * s_num
cols = tl.arange(0, BLOCK) # N <= BLOCK
mask = cols < N
s_cols = tl.arange(0, s_num)
mask = g_id * BLOCK + cols < SIZE
y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
y = tl.reshape(y, (s_num, 128))
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
y_s = _absmax / int8_max
_absmax = tl.maximum(tl.max(tl.abs(y), axis=1), eps)
y_s = (_absmax / int8_max).reshape(s_num, 1)
y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty)
y_q = tl.reshape(y_q, (s_num*128))
y_s = tl.reshape(y_s, (s_num))
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
tl.store(y_s_ptr + s_cols, y_s.to(y_s_ptr.dtype.element_ty))
def per_token_group_quant_int8(
......@@ -139,30 +147,38 @@ def per_token_group_quant_int8(
int8_min = iinfo.min
x_q = torch.empty_like(x, device=x.device, dtype=dtype)
M = x.numel() // group_size
N = group_size
m = x.shape[0]
if m<=16:
config={"BLOCK":128,"s_num":1,"num_warps":1,"num_stages":1}
elif m<=256:
config={"BLOCK":1024,"s_num":8,"num_warps":4,"num_stages":1}
else:
config={"BLOCK":2048,"s_num":16,"num_warps":4,"num_stages":2}
grid = lambda META: (
triton.cdiv(x.numel(), META['BLOCK']),
)
x_s = torch.empty(
x.shape[:-1] + (x.shape[-1] // group_size,),
device=x.device,
dtype=torch.float32,
)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
num_stages = 1
_per_token_group_quant_int8[(M,)](
_per_token_group_quant_int8[grid](
x,
x_q,
x_s,
group_size,
N,
# M,
# K,
# N,
x.numel(),
eps,
int8_min=int8_min,
int8_max=int8_max,
BLOCK=BLOCK,
num_warps=num_warps,
num_stages=num_stages,
**config
)
return x_q, x_s
......@@ -458,59 +474,6 @@ def w8a8_block_int8_matmul(
return C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Reference block-wise quantized matmul implemented with plain torch ops.

    Computes ``C = A @ B.T`` tile by tile, rescaling every partial product
    with the matching activation scale from ``As`` and weight scale from
    ``Bs``. Accumulation happens in float32; the result is cast to
    ``output_dtype`` at the end.

    Args:
        A: Activation tensor of shape ``(..., K)``; leading dims are flattened
            into ``M`` rows for the computation.
        B: Contiguous 2-D weight tensor of shape ``(N, K)``.
        As: Activation scales of shape ``A.shape[:-1] + (ceil(K / block_k),)``,
            i.e. one scale per row and per K-block.
        Bs: 2-D weight scales of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``(block_n, block_k)``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.

    Raises:
        AssertionError: if the shapes of the inputs are inconsistent with
            each other or with ``block_size``.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    M = A.numel() // A.shape[-1]
    N, K = B.shape
    origin_C_shape = A.shape[:-1] + (N,)

    # Collapse all leading dims so the tiled loop only deals with 2-D data.
    A = A.reshape(M, A.shape[-1])
    As = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for i in range(k_tiles):
        # The last K-block may be shorter than block_k.
        k_lo, k_hi = i * block_k, min((i + 1) * block_k, K)
        a = A[:, k_lo:k_hi]
        a_scale = As[:, i:i + 1]  # kept as (M, 1) so it broadcasts over columns
        for j in range(n_tiles):
            # The last N-block may be shorter than block_n.
            n_lo, n_hi = j * block_n, min((j + 1) * block_n, N)
            b = B[n_lo:n_hi, k_lo:k_hi]
            s = a_scale * Bs[j][i]
            # In-place add on a slice view accumulates directly into C.
            C[:, n_lo:n_hi] += torch.matmul(a, b.t()) * s

    return C.reshape(origin_C_shape).to(output_dtype)
def apply_w8a8_block_int8_linear(
input: torch.Tensor,
weight: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment