Commit 59259b56 authored by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

fix compile and op issues

See merge request OpenDAS/sglang!38
parents 1a73f6a3 263b5bde
...@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
def blaslt_scaled_mm(a: torch.Tensor,
b: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
m = a.shape[0]
n = b.shape[0]
k = a.shape[1]
_, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
return out
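# --- Illustrative usage sketch (editor's addition, not part of this merge request). ---
# Assumes int8 activations `a` of shape [m, k], int8 weights `b` stored as [n, k]
# (matching the 'NT' layout passed to quant_ops.hipblaslt_w8a8_gemm above) and fp32
# per-row scales; shapes, dtypes and the sample sizes below are assumptions.
def _blaslt_scaled_mm_example() -> torch.Tensor:
    m, n, k = 32, 4096, 4096
    a = torch.randint(-128, 127, (m, k), dtype=torch.int8, device="cuda")
    b = torch.randint(-128, 127, (n, k), dtype=torch.int8, device="cuda")
    scale_a = torch.rand(m, 1, dtype=torch.float32, device="cuda")
    scale_b = torch.rand(n, 1, dtype=torch.float32, device="cuda")
    out = blaslt_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    return out  # expected shape: [m, n]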
def triton_int8_gemm_helper(m: int,
n: int,
k: int,
...
...@@ -725,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
self.packed_recv_count = self.handle = None
return combined_hidden_states, event, hook
@torch._dynamo.disable()
def _get_buffer(self):
DeepEPBuffer.set_dispatch_mode_as_low_latency()
return DeepEPBuffer.get_deepep_buffer(
...@@ -805,6 +806,7 @@ class DeepEPDispatcher(BaseDispatcher):
)
self._dispatch_intermediate_state = inner_state
@torch._dynamo.disable()
def dispatch_b(self):
self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
inner_state = self._dispatch_intermediate_state
...@@ -832,6 +834,7 @@ class DeepEPDispatcher(BaseDispatcher):
)
self._combine_intermediate_state = inner_state
@torch._dynamo.disable()
def combine_b(self):
self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
inner_state = self._combine_intermediate_state
...
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import os
import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast
...@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
from sglang.srt import _custom_ops as ops
from sglang.srt.utils import W8a8GetCacheJSON
logger = logging.getLogger(__name__)
__all__ = ["CompressedTensorsLinearMethod"]
...@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
def __init__(self, quantization_config: CompressedTensorsConfig):
self.quantization_config = quantization_config
self.tritonsingleton= W8a8GetCacheJSON()
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
if configs_dict:
self.tritonsingleton.triton_json_dict.update(configs_dict)
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
elif self.w8a8_strategy==3:
layer.weight.data = layer.weight.data.T
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer)
def create_weights(
...
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Callable, Optional
import torch
...@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
from sglang.srt.utils import is_cuda
from lmslim import quant_ops
from sglang.srt import _custom_ops as ops
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
from sglang.srt.utils import W8a8GetCacheJSON
W8A8_TRITONJSON=W8a8GetCacheJSON()
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self.strategy = strategy
self.is_static_input_scheme = is_static_input_scheme
self.input_symmetric = input_symmetric
self.w8a8_strategy=int(os.getenv('W8A8_SUPPORT_METHODS', '1')) # TODO
@classmethod
def get_min_capability(cls) -> int:
...@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
)
layer.register_parameter("input_zero_point", input_zero_point)
@torch._dynamo.disable()
def apply_weights(
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
# TODO: add cutlass_scaled_mm_azp support
x_q, x_scale = per_token_quant_int8(x)
# return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
# TODO: fix with lmslim/lightop
return quant_ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
)
if self.w8a8_strategy==1:
m=x_q.shape[0]
k=x_q.shape[1]
n=layer.weight.shape[1]
if len(W8A8_TRITONJSON.triton_json_dict)==0:
best_config=None
elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
if m<=16:
m_=m
elif m<=64:
m_= (m + 3) & -4  # round up to the nearest multiple of 4
elif m<=160:
m_=(m + 7) & -8
elif m<200: #256
m_=160
elif m<480: #512
m_=256
elif m<960: #1024
m_=512
elif m<2048:
m_=1024
elif m<4096:
m_=2048
elif m<6000:
m_=4096
else:
m_=8192
best_config=W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
else:
best_config=None
return ops.triton_scaled_mm(
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias, best_config=best_config
)
elif self.w8a8_strategy==2:
return ops.cutlass_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
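# --- Editor's sketch (not part of this merge request): the M-bucketing above, pulled out
# as a standalone helper for readability. It mirrors the if/elif ladder that picks the
# tuned-config key f"{m_}_{n}_{k}" from W8A8_TRITONJSON.triton_json_dict; the helper name
# is made up and the bucket boundaries are copied verbatim from the code above. This path
# only applies when W8A8_SUPPORT_METHODS == 1; 2 routes to ops.cutlass_scaled_mm, 3 to
# ops.blaslt_scaled_mm, and any other value falls back to ops.rocblas_scaled_mm.
def _round_m_to_tuned_bucket(m: int) -> int:
    if m <= 16:
        return m
    if m <= 64:
        return (m + 3) & -4   # round up to the next multiple of 4
    if m <= 160:
        return (m + 7) & -8   # round up to the next multiple of 8
    for bound, bucket in ((200, 160), (480, 256), (960, 512),
                          (2048, 1024), (4096, 2048), (6000, 4096)):
        if m < bound:
            return bucket
    return 8192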
...@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
per_token_group_quant_int8,
per_token_quant_int8)
from sglang.srt import _custom_ops as ops
from vllm.utils import W8a8GetCacheJSON
from sglang.srt.utils import W8a8GetCacheJSON
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
import os
...@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
)
layer.register_parameter("weight_scale", weight_scale)
@torch._dynamo.disable() # TODO: performance optimization needs lmslim/lightop support
def apply(
self,
layer: torch.nn.Module,
...@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias)
elif self.w8a8_strategy==3:
return ops.blaslt_scaled_mm(x_q,
layer.weight,
scale_a=x_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=None)
else:
return ops.rocblas_scaled_mm(x_q,
layer.weight,
...
...@@ -653,7 +653,7 @@ class ForwardBatch:
bs = self.batch_size,
)
else:
logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
# logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
self.req_to_token_pool.req_to_token,
self.req_pool_indices,
...
...@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
return CachedKernel(fn, key_fn)
return decorator
# from vllm
class W8a8GetCacheJSON:
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls)
cls._instance._initialize()
return cls._instance
def _initialize(self):
current_folder_path = os.path.dirname(os.path.abspath(__file__))
json_folder_path=current_folder_path+'/../../lmslim/configs/w8a8'
self.triton_json_dir=(os.getenv('TRITON_JSON_DIR', json_folder_path))
self.triton_json_dict={}
self.triton_moejson_dict={}
self.triton_json_list=[]
self.weight_shapes=[]
self.moe_weight_shapes=[]
arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
device_name =arch_name+'_'+str(arch_cu)+'cu'
self.device_name=device_name
self.topk=1
self.quant_method=None
# Called at the end to generate the model.json config file
def gen_model_json(self,E:Optional[int]=0,block_size:Optional[list]=None):
json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
if json_dir != "None" and os.path.exists(json_dir):
# Generate the model config file
# logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
config = {
"layers": {
"linear": {
"shapes": [],
"m_range":"None",
},
"moe": {
"shapes": [],
"m_range": "None",
"topk": self.topk
}
},
"quantization_config": {
"quant_method": self.quant_method,
"weight_block_size": "None"
}
}
# Handle MoE shapes
for shape in self.moe_weight_shapes:
if len(shape) == 4: # assume the MoE shape is in [E, N1, N2, K] format
moe_config = {
"E": shape[0],
"N1": shape[1],
"N2": shape[2],
"K": shape[3], # 默认值
}
config["layers"]["moe"]["shapes"].append(moe_config)
for shape in self.weight_shapes:
config["layers"]["linear"]["shapes"].append(shape)
if block_size is not None:
config["quantization_config"]["weight_block_size"]=block_size
with open(json_dir+"/model.json", 'w') as f:
json.dump(config, f, indent=4)
# else:
# logger.info("LMSLIM_TUNING_JSON is not set")
def getspec_config(self,configs_dict,M,N,K):
if f"{M}_{N}_{K}" in configs_dict:
return configs_dict[f"{M}_{N}_{K}"]
else:
return None
def get_triton_cache(self,file_path,n,k):
# Used when not tuning; returns None directly if the file does not exist
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key: config form, i.e. "M_N_K": config
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
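# --- Editor's note (not part of this merge request): assumed on-disk layout of the cache
# file parsed above. configs_key = f"{sub_key}_{key}", together with the f"{m_}_{n}_{k}"
# lookups in CompressedTensorsW8A8Int8.apply_weights, implies an outer "N_K" key mapping
# to per-M entries (illustrative field names):
#   {"4096_4096": {"1": {"BLOCK_M": 16, ...}, "256": {"BLOCK_M": 64, ...}}}
# which get_triton_cache flattens into {"1_4096_4096": {...}, "256_4096_4096": {...}}.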
def get_w8a8json_name(self,n,k):
return self.triton_json_dir+f"/W8A8_{n}_{k}_{self.device_name}.json"
def get_blockint8_triton_cache(self,file_path,n,k,block_n,block_k):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key: config form, i.e. "M_N_K": config
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
def get_blockint8json_name(self,n,k,block_n,block_k):
return self.triton_json_dir+f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
def get_moeint8json_name(self,E,N1,N2,K,TOPK,
block_size:Optional[list]=None,use_int4_w4a8:Optional[bool]=False):
if use_int4_w4a8:
if block_size is not None:
return self.triton_json_dir+f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
if block_size is not None:
return self.triton_json_dir+f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
else:
return self.triton_json_dir+f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
def get_moeint8_triton_cache(self,file_path,E,N1,N2,K,TOPK):
cache_json_file=file_path
if os.path.exists(file_path):
#try:
with open(cache_json_file, 'r') as file:
cachedata = json.load(file)
else:
return None
# Parse the whole cache into key: config form, i.e. "M_N_K": [config1, config2]
configs_dict={}
for key, value in cachedata.items():
for sub_key, sub_value in value.items():
configs_key= f"{sub_key}_{key}"
configs_dict[configs_key]=sub_value
return configs_dict
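# --- Editor's usage sketch (not part of this merge request): how the singleton above is
# consumed by the W8A8 linear methods. Constructing it requires a visible GPU (device_name
# is read from torch.cuda device properties); the shape numbers are illustrative.
def _w8a8_cache_lookup_example(n: int = 4096, k: int = 4096) -> None:
    cache = W8a8GetCacheJSON()                    # same instance is returned everywhere
    json_file = cache.get_w8a8json_name(n, k)     # per-shape, per-device cache file
    configs = cache.get_triton_cache(json_file, n, k)
    if configs:                                   # None when the file does not exist
        cache.triton_json_dict.update(configs)
        best = cache.getspec_config(cache.triton_json_dict, 256, n, k)
        print(json_file, best)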