Commit 48542418 authored by renzhc

fix lightop gemm smooth and add some trivial changes

parent 1a73f6a3
@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
    return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)


def blaslt_scaled_mm(a: torch.Tensor,
                     b: torch.Tensor,
                     scale_a: torch.Tensor,
                     scale_b: torch.Tensor,
                     out_dtype: torch.dtype,
                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    m = a.shape[0]
    n = b.shape[0]
    k = a.shape[1]
    _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
    return out


def triton_int8_gemm_helper(m: int,
                            n: int,
                            k: int,
......
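For reference, a minimal call sketch of the new blaslt_scaled_mm path. The shapes, dtypes, and scale layouts below are assumptions inferred from the 'NT' layout and the n = b.shape[0] indexing, not part of the commit; note also that the bias argument is accepted but never forwarded to hipblaslt_w8a8_gemm.

```python
import torch
# assumes blaslt_scaled_mm from the module above is importable

M, N, K = 32, 4096, 4096
a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")  # activations
b = torch.randint(-128, 127, (N, K), dtype=torch.int8, device="cuda")  # weights, row-major
scale_a = torch.rand(M, 1, device="cuda", dtype=torch.float32)  # per-token scale (assumed)
scale_b = torch.rand(N, 1, device="cuda", dtype=torch.float32)  # per-channel scale (assumed)

out = blaslt_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
assert out.shape == (M, N)
```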
@@ -725,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        self.packed_recv_count = self.handle = None
        return combined_hidden_states, event, hook

    @torch._dynamo.disable()
    def _get_buffer(self):
        DeepEPBuffer.set_dispatch_mode_as_low_latency()
        return DeepEPBuffer.get_deepep_buffer(
......
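The newly added @torch._dynamo.disable() keeps _get_buffer out of any torch.compile graph, forcing the DeepEP buffer setup to run eagerly. A minimal sketch of the pattern (the names below are illustrative, not from the commit):

```python
import torch

@torch._dynamo.disable()
def _get_buffer_stub():
    # Graph-unfriendly, side-effectful setup (cf. set_dispatch_mode_as_low_latency)
    # runs in eager mode instead of being traced by dynamo.
    return torch.empty(1)

@torch.compile
def forward(x: torch.Tensor) -> torch.Tensor:
    _ = _get_buffer_stub()  # executed eagerly; dynamo inserts a graph break here
    return x * 2

print(forward(torch.ones(4)))
```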
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON # TODO: remove vllm dependency
logger = logging.getLogger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config
        self.tritonsingleton = W8a8GetCacheJSON()
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        n = layer.weight.shape[0]
        k = layer.weight.shape[1]
        if self.w8a8_strategy == 1:
            if [n, k] not in self.tritonsingleton.weight_shapes:
                self.tritonsingleton.weight_shapes.append([n, k])
                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
                configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
                if configs_dict:
                    self.tritonsingleton.triton_json_dict.update(configs_dict)
                    for key, value in configs_dict.items():
                        m = int(key.split('_')[0])
                        ops.triton_int8_gemm_helper(
                            m=m, n=n, k=k,
                            per_token_act_quant=True,
                            per_out_channel_weight_quant=True,
                            use_bias=False,
                            device=layer.weight.device,
                            best_config=value)
        elif self.w8a8_strategy == 3:
            layer.weight.data = layer.weight.data.T
        else:
            weight_data = layer.weight.data
            _weight = weight_data.T.contiguous().reshape(n, -1)
            layer.weight.data = _weight
        self.tritonsingleton.gen_model_json()
        layer.scheme.process_weights_after_loading(layer)
def create_weights(
......
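The three W8A8_SUPPORT_METHODS branches leave layer.weight in different layouts: strategy 1 keeps the (n, k) tensor and pre-warms the Triton kernel for every tuned m; strategy 3 stores the transpose for the hipblaslt path; the fallback repacks the transposed data back into an (n, -1) view for rocBLAS. A standalone sketch of just the layout handling (the helper name and shapes are illustrative):

```python
import torch

def repack_weight(weight: torch.Tensor, strategy: int) -> torch.Tensor:
    # Mirrors the branches in process_weights_after_loading above.
    n, k = weight.shape
    if strategy == 1:
        return weight                             # Triton path: keep (n, k)
    if strategy == 3:
        return weight.T                           # hipblaslt path: (k, n) view
    return weight.T.contiguous().reshape(n, -1)   # rocBLAS path: repacked to (n, k)

w = torch.randint(-128, 127, (4096, 2048), dtype=torch.int8)
assert repack_weight(w, 1).shape == (4096, 2048)
assert repack_weight(w, 3).shape == (2048, 4096)
assert repack_weight(w, 0).shape == (4096, 2048)  # same shape, transposed memory layout
```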
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, Optional
import torch
@@ -19,11 +20,14 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops
_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm

# TODO: remove vllm deps
from sglang.srt.utils import W8a8GetCacheJSON

W8A8_TRITONJSON = W8a8GetCacheJSON()
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
@@ -33,6 +37,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric
        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))  # TODO

    @classmethod
    def get_min_capability(cls) -> int:
@@ -163,14 +168,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
        )
        layer.register_parameter("input_zero_point", input_zero_point)

    @torch._dynamo.disable()
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
    ) -> torch.Tensor:
        # TODO: add cutlass_scaled_mm_azp support
        x_q, x_scale = per_token_quant_int8(x)
        # Previous single-path implementation, removed by this commit:
        # # TODO: fix with lmslim/lightop
        # return quant_ops.triton_scaled_mm(
        #     x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        # )
        # # return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
        if self.w8a8_strategy == 1:
            m = x_q.shape[0]
            k = x_q.shape[1]
            n = layer.weight.shape[1]
            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
                best_config = None
            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
                if m <= 16:
                    m_ = m
                elif m <= 64:
                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
                elif m <= 160:
                    m_ = (m + 7) & -8  # round up to the nearest multiple of 8
                elif m < 200:  # 256
                    m_ = 160
                elif m < 480:  # 512
                    m_ = 256
                elif m < 960:  # 1024
                    m_ = 512
                elif m < 2048:
                    m_ = 1024
                elif m < 4096:
                    m_ = 2048
                elif m < 6000:
                    m_ = 4096
                else:
                    m_ = 8192
                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
            else:
                best_config = None
            return ops.triton_scaled_mm(
                x_q, layer.weight, x_scale, layer.weight_scale,
                out_dtype=x.dtype, bias=bias, best_config=best_config
            )
        elif self.w8a8_strategy == 2:
            return ops.cutlass_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        elif self.w8a8_strategy == 3:
            # bias is not forwarded on the hipblaslt path
            return ops.blaslt_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=None)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
                                         scale_a=x_scale,
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
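The m-bucketing in apply_weights snaps the runtime batch dimension onto the grid of tuned shapes so a cached Triton config can be reused. Extracted as a pure function for illustration (thresholds copied verbatim; the helper name is not part of the commit):

```python
def bucket_m(m: int) -> int:
    # Mirrors the if/elif ladder in apply_weights above.
    if m <= 16:
        return m
    if m <= 64:
        return (m + 3) & -4   # round up to a multiple of 4
    if m <= 160:
        return (m + 7) & -8   # round up to a multiple of 8
    for bound, bucket in [(200, 160), (480, 256), (960, 512),
                          (2048, 1024), (4096, 2048), (6000, 4096)]:
        if m < bound:
            return bucket
    return 8192

assert bucket_m(10) == 10
assert bucket_m(33) == 36     # multiple of 4
assert bucket_m(100) == 104   # multiple of 8
assert bucket_m(300) == 256
assert bucket_m(7000) == 8192
```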
@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
    per_token_group_quant_int8,
    per_token_quant_int8)
from sglang.srt import _custom_ops as ops
# removed: from vllm.utils import W8a8GetCacheJSON
from sglang.srt.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
        )
        layer.register_parameter("weight_scale", weight_scale)

    # removed: @torch._dynamo.disable()  # TODO: performance optimization needs lmslim/lightop support
    def apply(
        self,
        layer: torch.nn.Module,
@@ -227,6 +226,13 @@
                                         scale_b=layer.weight_scale,
                                         out_dtype=x.dtype,
                                         bias=bias)
        elif self.w8a8_strategy == 3:
            # bias is not forwarded on the hipblaslt path
            return ops.blaslt_scaled_mm(x_q,
                                        layer.weight,
                                        scale_a=x_scale,
                                        scale_b=layer.weight_scale,
                                        out_dtype=x.dtype,
                                        bias=None)
        else:
            return ops.rocblas_scaled_mm(x_q,
                                         layer.weight,
......
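Both the W8A8 and W4A8 apply paths dispatch on the same environment variable. The mapping implied by the branches (inferred from the code, not documented in the commit): 1 selects the Triton kernel, 2 cutlass, 3 hipblaslt, anything else rocBLAS.

```python
import os

# Hypothetical illustration of the dispatch contract shared by both schemes.
os.environ["W8A8_SUPPORT_METHODS"] = "3"
strategy = int(os.getenv("W8A8_SUPPORT_METHODS", "1"))
backend = {1: "triton", 2: "cutlass", 3: "hipblaslt"}.get(strategy, "rocblas")
assert backend == "hipblaslt"
```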
@@ -653,7 +653,7 @@ class ForwardBatch:
                bs=self.batch_size,
            )
        else:
            # logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
            create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
                self.req_to_token_pool.req_to_token,
                self.req_pool_indices,
......
@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
        return CachedKernel(fn, key_fn)

    return decorator
# from vllm
class W8a8GetCacheJSON:
    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ takes no extra arguments; don't forward *args/**kwargs
            cls._instance = super().__new__(cls)
            cls._instance._initialize()
        return cls._instance
    def _initialize(self):
        current_folder_path = os.path.dirname(os.path.abspath(__file__))
        json_folder_path = current_folder_path + '/../../lmslim/configs/w8a8'
        self.triton_json_dir = os.getenv('TRITON_JSON_DIR', json_folder_path)
        self.triton_json_dict = {}
        self.triton_moejson_dict = {}
        self.triton_json_list = []
        self.weight_shapes = []
        self.moe_weight_shapes = []
        arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
        arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
        device_name = arch_name + '_' + str(arch_cu) + 'cu'
        self.device_name = device_name
        self.topk = 1
        self.quant_method = None

    # Generates the model.json config file at the end of weight loading.
    def gen_model_json(self, E: Optional[int] = 0, block_size: Optional[list] = None):
        json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
        if json_dir != "None" and os.path.exists(json_dir):
            # Generate the model config file.
            # logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
            config = {
                "layers": {
                    "linear": {
                        "shapes": [],
                        "m_range": "None",
                    },
                    "moe": {
                        "shapes": [],
                        "m_range": "None",
                        "topk": self.topk
                    }
                },
                "quantization_config": {
                    "quant_method": self.quant_method,
                    "weight_block_size": "None"
                }
            }
            # Collect MoE shapes; a MoE shape is assumed to be [E, N1, N2, K].
            for shape in self.moe_weight_shapes:
                if len(shape) == 4:
                    moe_config = {
                        "E": shape[0],
                        "N1": shape[1],
                        "N2": shape[2],
                        "K": shape[3],
                    }
                    config["layers"]["moe"]["shapes"].append(moe_config)
            for shape in self.weight_shapes:
                config["layers"]["linear"]["shapes"].append(shape)
            if block_size is not None:
                config["quantization_config"]["weight_block_size"] = block_size
            with open(json_dir + "/model.json", 'w') as f:
                json.dump(config, f, indent=4)
        # else:
        #     logger.info("LMSLIM_TUNING_JSON is not set")
    def getspec_config(self, configs_dict, M, N, K):
        if f"{M}_{N}_{K}" in configs_dict:
            return configs_dict[f"{M}_{N}_{K}"]
        else:
            return None

    def get_triton_cache(self, file_path, n, k):
        # Used outside of tuning; returns None if the cache file does not exist.
        cache_json_file = file_path
        if os.path.exists(file_path):
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the cache into "M_N_K": config entries.
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_w8a8json_name(self, n, k):
        return self.triton_json_dir + f"/W8A8_{n}_{k}_{self.device_name}.json"
    def get_blockint8_triton_cache(self, file_path, n, k, block_n, block_k):
        cache_json_file = file_path
        if os.path.exists(file_path):
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the cache into "M_N_K": config entries.
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict

    def get_blockint8json_name(self, n, k, block_n, block_k):
        return self.triton_json_dir + f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
    def get_moeint8json_name(self, E, N1, N2, K, TOPK,
                             block_size: Optional[list] = None, use_int4_w4a8: Optional[bool] = False):
        if use_int4_w4a8:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
        else:
            if block_size is not None:
                return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
            else:
                return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"

    def get_moeint8_triton_cache(self, file_path, E, N1, N2, K, TOPK):
        cache_json_file = file_path
        if os.path.exists(file_path):
            with open(cache_json_file, 'r') as file:
                cachedata = json.load(file)
        else:
            return None
        # Flatten the cache into "M_N_K": [config1, config2] entries.
        configs_dict = {}
        for key, value in cachedata.items():
            for sub_key, sub_value in value.items():
                configs_key = f"{sub_key}_{key}"
                configs_dict[configs_key] = sub_value
        return configs_dict
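Pieced together from the call sites, typical use of the singleton looks roughly as follows; the file name, config fields, and the inferred cache layout are illustrative rather than documented (note that instantiation requires a GPU, since _initialize queries torch.cuda device properties):

```python
cache = W8a8GetCacheJSON()            # singleton
assert cache is W8a8GetCacheJSON()    # every call returns the same instance

n, k = 4096, 2048
json_file = cache.get_w8a8json_name(n, k)
# e.g. ".../configs/w8a8/W8A8_4096_2048_gfx942_304cu.json" (device suffix varies)

# The cache files evidently nest configs as {"N_K": {"M": config}}; the
# get_*_triton_cache helpers flatten them to "M_N_K" keys, which is what the
# apply_weights lookups expect. A hypothetical file content:
cachedata = {"4096_2048": {"16": {"BLOCK_M": 16}, "256": {"BLOCK_M": 64}}}
configs_dict = {f"{m}_{nk}": cfg
                for nk, inner in cachedata.items()
                for m, cfg in inner.items()}
assert "16_4096_2048" in configs_dict
```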