Commit 268d8a77 authored by gaoqiong's avatar gaoqiong
Browse files

增加线性int8 gemm配置

parent 92545504
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/sgl-project/sglang/pull/3730
import functools import functools
import json import json
import logging import logging
...@@ -336,22 +335,59 @@ def w8a8_block_int8_matmul( ...@@ -336,22 +335,59 @@ def w8a8_block_int8_matmul(
C_shape = A.shape[:-1] + (N,) C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) #configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1])
if configs: #if configs:
# If an optimal configuration map has been found, look up the # # If an optimal configuration map has been found, look up the
# optimal config # # optimal config
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] # config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: #else:
# Default config # Default config
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1] # Block-wise quant: BLOCK_SIZE_K must be divisable by block_size[1]
#print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1])) #print("block_size[0]:{},block_size[1]:{}".format(block_size[0],block_size[1]))
# config = {
# "BLOCK_SIZE_M": 32, #64
# "BLOCK_SIZE_N": block_size[0],
# "BLOCK_SIZE_K": block_size[1],
# "GROUP_SIZE_M": 32,
# "num_warps": 4,
# "num_stages": 3,
# }
if M<=64:
config = {
"BLOCK_SIZE_M": 16, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
elif M<128:
config = { config = {
"BLOCK_SIZE_M": 32, #64 "BLOCK_SIZE_M": 32, #64
"BLOCK_SIZE_N": block_size[0], "BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": block_size[1], "BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32, "GROUP_SIZE_M": 2,
"num_warps": 4, "num_warps": 4,
"num_stages": 3, "num_stages": 0,
}
elif M<=256:
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 2,
"num_warps": 4,
"num_stages": 0,
}
else :
config = {
"BLOCK_SIZE_M": 64, #64
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 0,
} }
def grid(META): def grid(META):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment