change / sglang · Commits · 59259b56

Commit 59259b56, authored Nov 25, 2025 by lizhigong

Merge branch 'v0.5.4_rzc' into 'v0.5.4_dev'

fix compile and op issues

See merge request OpenDAS/sglang!38

Parents: 1a73f6a3, 263b5bde
Showing 7 changed files with 276 additions and 12 deletions (+276 −12).
python/sglang/srt/_custom_ops.py  (+12 −0)
python/sglang/srt/layers/moe/token_dispatcher/deepep.py  (+3 −0)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py  (+29 −1)
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py  (+68 −8)
python/sglang/srt/layers/quantization/slimquant_w4a8.py  (+8 −2)
python/sglang/srt/model_executor/forward_batch_info.py  (+1 −1)
python/sglang/srt/utils/common.py  (+155 −0)
python/sglang/srt/_custom_ops.py

@@ -332,6 +332,18 @@ def rocblas_scaled_mm(a: torch.Tensor,
     return quant_ops.rocblas_scaled_mm_nn(a, b, scale_a, scale_b, out_dtype, bias)
 
+
+def blaslt_scaled_mm(a: torch.Tensor,
+                     b: torch.Tensor,
+                     scale_a: torch.Tensor,
+                     scale_b: torch.Tensor,
+                     out_dtype: torch.dtype,
+                     bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    m = a.shape[0]
+    n = b.shape[0]
+    k = a.shape[1]
+    _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
+    return out
+
+
 def triton_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
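Note on the new function: blaslt_scaled_mm reads n from b.shape[0] and invokes the GEMM with 'NT', so the int8 weight is expected in [n, k] layout. A minimal usage sketch follows; the tensor shapes and scale layouts are assumptions, not taken from this diff, and it presumes an lmslim build exposing quant_ops.hipblaslt_w8a8_gemm as imported above.

    # Hypothetical call into the new hipBLASLt path (requires a ROCm GPU + lmslim).
    import torch
    from sglang.srt import _custom_ops as ops

    a = torch.randint(-128, 127, (32, 4096), dtype=torch.int8, device="cuda")      # activations, [m, k]
    b = torch.randint(-128, 127, (11008, 4096), dtype=torch.int8, device="cuda")   # weights, [n, k] ('NT')
    scale_a = torch.rand(32, 1, device="cuda")     # assumed per-token activation scales
    scale_b = torch.rand(11008, 1, device="cuda")  # assumed per-channel weight scales

    out = ops.blaslt_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.float16)
    assert out.shape == (32, 11008)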
python/sglang/srt/layers/moe/token_dispatcher/deepep.py

@@ -725,6 +725,7 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         self.packed_recv_count = self.handle = None
         return combined_hidden_states, event, hook
 
+    @torch._dynamo.disable()
     def _get_buffer(self):
         DeepEPBuffer.set_dispatch_mode_as_low_latency()
         return DeepEPBuffer.get_deepep_buffer(

@@ -805,6 +806,7 @@ class DeepEPDispatcher(BaseDispatcher):
             )
         )
         self._dispatch_intermediate_state = inner_state
 
+    @torch._dynamo.disable()
     def dispatch_b(self):
         self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
         inner_state = self._dispatch_intermediate_state

@@ -832,6 +834,7 @@ class DeepEPDispatcher(BaseDispatcher):
             )
         )
         self._combine_intermediate_state = inner_state
 
+    @torch._dynamo.disable()
     def combine_b(self):
         self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
         inner_state = self._combine_intermediate_state
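The three new @torch._dynamo.disable() decorators keep the DeepEP buffer selection and the stage-B dispatch/combine steps out of torch.compile tracing; the calls run eagerly and the surrounding graph breaks around them. A minimal sketch of that effect, with illustrative function names that are not from this repo:

    import torch

    @torch._dynamo.disable()
    def comm_hook(x):
        # stands in for a communication-library call Dynamo cannot trace
        return x * 2

    @torch.compile
    def forward(x):
        y = x + 1          # compiled
        y = comm_hook(y)   # graph break: executes eagerly
        return y - 1       # compiled

    print(forward(torch.ones(4)))  # tensor([3., 3., 3., 3.])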
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 from __future__ import annotations
 
+import os
 import logging
 from contextlib import suppress
 from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast

@@ -46,6 +46,9 @@ from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt import _custom_ops as ops
+from sglang.srt.utils import W8a8GetCacheJSON
 
 logger = logging.getLogger(__name__)
 
 __all__ = ["CompressedTensorsLinearMethod"]

@@ -590,8 +593,33 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
     def __init__(self, quantization_config: CompressedTensorsConfig):
         self.quantization_config = quantization_config
+        self.tritonsingleton = W8a8GetCacheJSON()
+        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        n = layer.weight.shape[0]
+        k = layer.weight.shape[1]
+        if self.w8a8_strategy == 1:
+            if [n, k] not in self.tritonsingleton.weight_shapes:
+                self.tritonsingleton.weight_shapes.append([n, k])
+                json_file = self.tritonsingleton.get_w8a8json_name(n, k)
+                configs_dict = self.tritonsingleton.get_triton_cache(json_file, n, k)
+                if configs_dict:
+                    self.tritonsingleton.triton_json_dict.update(configs_dict)
+                    for key, value in configs_dict.items():
+                        m = int(key.split('_')[0])
+                        ops.triton_int8_gemm_helper(m=m, n=n, k=k,
+                                                    per_token_act_quant=True,
+                                                    per_out_channel_weight_quant=True,
+                                                    use_bias=False,
+                                                    device=layer.weight.device,
+                                                    best_config=value)
+        elif self.w8a8_strategy == 3:
+            layer.weight.data = layer.weight.data.T
+        else:
+            weight_data = layer.weight.data
+            _weight = weight_data.T.contiguous().reshape(n, -1)
+            layer.weight.data = _weight
+        self.tritonsingleton.gen_model_json()
         layer.scheme.process_weights_after_loading(layer)
 
     def create_weights(
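For reference, the whole backend choice introduced in this file hangs off one environment variable read at init time, and it also decides how the weight is laid out after loading. A hedged summary sketch; the mapping is inferred from the branches in this diff:

    # W8A8_SUPPORT_METHODS=1 -> Triton int8 GEMM, warmed up from cached JSON configs
    # W8A8_SUPPORT_METHODS=2 -> CUTLASS scaled mm (apply path only)
    # W8A8_SUPPORT_METHODS=3 -> hipBLASLt; weight stored transposed
    # anything else          -> rocBLAS; weight stored as weight.T.contiguous().reshape(n, -1)
    import os

    os.environ["W8A8_SUPPORT_METHODS"] = "3"  # must be set before model load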
python/sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py

 # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+import os
 from typing import Callable, Optional
 
 import torch

@@ -19,11 +20,13 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
 from lmslim.layers.gemm.int8_utils import per_token_quant_int8
 from sglang.srt.layers.quantization.utils import requantize_with_max_scale
 from sglang.srt.utils import is_cuda
-from lmslim import quant_ops
+from sglang.srt import _custom_ops as ops
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
+from sglang.srt.utils import W8a8GetCacheJSON
+W8A8_TRITONJSON = W8a8GetCacheJSON()
 
 class CompressedTensorsW8A8Int8(CompressedTensorsScheme):

@@ -33,6 +36,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
         self.input_symmetric = input_symmetric
+        self.w8a8_strategy = int(os.getenv('W8A8_SUPPORT_METHODS', '1'))  # TODO
 
     @classmethod
     def get_min_capability(cls) -> int:

@@ -163,14 +167,70 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
         )
         layer.register_parameter("input_zero_point", input_zero_point)
 
+    @torch._dynamo.disable()
     def apply_weights(
         self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
     ) -> torch.Tensor:
         # TODO: add cutlass_scaled_mm_azp support
         x_q, x_scale = per_token_quant_int8(x)
 
-        # return quant_ops.custom_scaled_mm(x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias)
-        return quant_ops.triton_scaled_mm(
-            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
-        )
+        # TODO: fix with lmslim/lightop
+        if self.w8a8_strategy == 1:
+            m = x_q.shape[0]
+            k = x_q.shape[1]
+            n = layer.weight.shape[1]
+            if len(W8A8_TRITONJSON.triton_json_dict) == 0:
+                best_config = None
+            elif f"1_{n}_{k}" in W8A8_TRITONJSON.triton_json_dict:
+                if m <= 16:
+                    m_ = m
+                elif m <= 64:
+                    m_ = (m + 3) & -4  # round up to the nearest multiple of 4
+                elif m <= 160:
+                    m_ = (m + 7) & -8
+                elif m < 200:  # 256
+                    m_ = 160
+                elif m < 480:  # 512
+                    m_ = 256
+                elif m < 960:  # 1024
+                    m_ = 512
+                elif m < 2048:
+                    m_ = 1024
+                elif m < 4096:
+                    m_ = 2048
+                elif m < 6000:
+                    m_ = 4096
+                else:
+                    m_ = 8192
+                best_config = W8A8_TRITONJSON.triton_json_dict[f"{m_}_{n}_{k}"]
+            else:
+                best_config = None
+            return ops.triton_scaled_mm(x_q,
+                                        layer.weight,
+                                        x_scale,
+                                        layer.weight_scale,
+                                        out_dtype=x.dtype,
+                                        bias=bias,
+                                        best_config=best_config)
+        elif self.w8a8_strategy == 2:
+            return ops.cutlass_scaled_mm(x_q,
+                                         layer.weight,
+                                         scale_a=x_scale,
+                                         scale_b=layer.weight_scale,
+                                         out_dtype=x.dtype,
+                                         bias=bias)
+        elif self.w8a8_strategy == 3:
+            return ops.blaslt_scaled_mm(x_q,
+                                        layer.weight,
+                                        scale_a=x_scale,
+                                        scale_b=layer.weight_scale,
+                                        out_dtype=x.dtype,
+                                        bias=None)
+        else:
+            return ops.rocblas_scaled_mm(x_q,
+                                         layer.weight,
+                                         scale_a=x_scale,
+                                         scale_b=layer.weight_scale,
+                                         out_dtype=x.dtype,
+                                         bias=bias)
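The Triton branch above rounds the batch dimension m up to the nearest tuned bucket before looking up a cached config. A standalone sketch of that bucketing, extracted from the diff for illustration, with spot checks of the bit tricks ((m + 3) & -4 rounds up to a multiple of 4, (m + 7) & -8 to a multiple of 8):

    def bucket_m(m: int) -> int:
        if m <= 16:
            return m
        elif m <= 64:
            return (m + 3) & -4
        elif m <= 160:
            return (m + 7) & -8
        elif m < 200:
            return 160
        elif m < 480:
            return 256
        elif m < 960:
            return 512
        elif m < 2048:
            return 1024
        elif m < 4096:
            return 2048
        elif m < 6000:
            return 4096
        return 8192

    assert bucket_m(33) == 36     # rounded up to a multiple of 4
    assert bucket_m(100) == 104   # rounded up to a multiple of 8
    assert bucket_m(3000) == 2048 # large m snaps to a coarse bucket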
python/sglang/srt/layers/quantization/slimquant_w4a8.py

@@ -15,7 +15,7 @@ from lmslim.layers.gemm.int8_utils import (
     per_token_group_quant_int8,
     per_token_quant_int8)
 from sglang.srt import _custom_ops as ops
-from vllm.utils import W8a8GetCacheJSON
+from sglang.srt.utils import W8a8GetCacheJSON
 from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
 import os

@@ -157,7 +157,6 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
         )
         layer.register_parameter("weight_scale", weight_scale)
 
-    @torch._dynamo.disable()
+    # TODO: performance optimization needs lmslim/lightop support
     def apply(
         self,
         layer: torch.nn.Module,

@@ -227,6 +226,13 @@ class SlimQuantW4A8Int8LinearMethod(LinearMethodBase):
                 scale_b=layer.weight_scale,
                 out_dtype=x.dtype,
                 bias=bias)
+        elif self.w8a8_strategy == 3:
+            return ops.blaslt_scaled_mm(x_q,
+                                        layer.weight,
+                                        scale_a=x_scale,
+                                        scale_b=layer.weight_scale,
+                                        out_dtype=x.dtype,
+                                        bias=None)
         else:
             return ops.rocblas_scaled_mm(x_q,
                                          layer.weight,
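Note that the hipBLASLt branch passes bias=None even when the caller supplies a bias, both here and in the w8a8 scheme above. If a bias were ever needed on that path, a hypothetical wrapper could add it after the GEMM; this is a sketch only, not part of this change:

    def blaslt_linear(ops, layer, x_q, x_scale, bias, out_dtype):
        out = ops.blaslt_scaled_mm(x_q, layer.weight,
                                   scale_a=x_scale,
                                   scale_b=layer.weight_scale,
                                   out_dtype=out_dtype,
                                   bias=None)  # this diff's hipBLASLt path ignores bias
        if bias is not None:
            out = out + bias  # hypothetical post-GEMM bias add
        return out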
python/sglang/srt/model_executor/forward_batch_info.py

@@ -653,7 +653,7 @@ class ForwardBatch:
                 bs=self.batch_size,
             )
         else:
-            logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
+            # logger.info("SGLANG_CREATE_CHUNKED_PREFIX_CACHE_KV_INDICES=0")
             create_chunked_prefix_cache_kv_indices[(self.batch_size,)](
                 self.req_to_token_pool.req_to_token,
                 self.req_pool_indices,
python/sglang/srt/utils/common.py

@@ -3528,3 +3528,158 @@ def cached_triton_kernel(key_fn=None):
             return CachedKernel(fn, key_fn)
 
     return decorator
+
+
+# from vllm
+class W8a8GetCacheJSON:
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = super(W8a8GetCacheJSON, cls).__new__(cls, *args, **kwargs)
+            cls._instance._initialize()
+        return cls._instance
+
+    def _initialize(self):
+        current_folder_path = os.path.dirname(os.path.abspath(__file__))
+        json_folder_path = current_folder_path + '/../../lmslim/configs/w8a8'
+        self.triton_json_dir = (os.getenv('TRITON_JSON_DIR', json_folder_path))
+        self.triton_json_dict = {}
+        self.triton_moejson_dict = {}
+        self.triton_json_list = []
+        self.weight_shapes = []
+        self.moe_weight_shapes = []
+        arch_name = torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0]
+        arch_cu = torch.cuda.get_device_properties(torch.cuda.current_device()).multi_processor_count
+        device_name = arch_name + '_' + str(arch_cu) + 'cu'
+        self.device_name = device_name
+        self.topk = 1
+        self.quant_method = None
+
+    # Destructor-style hook: generates the model.json config file at the end
+    def gen_model_json(self, E: Optional[int] = 0, block_size: Optional[list] = None):
+        json_dir = os.getenv('LMSLIM_TUNING_JSON', "None")
+        if json_dir != "None" and os.path.exists(json_dir):
+            # Generate the model config file
+            # logger.info("model_tuning.json is at LMSLIM_TUNING_JSON:%s", json_dir)
+            config = {
+                "layers": {
+                    "linear": {
+                        "shapes": [],
+                        "m_range": "None",
+                    },
+                    "moe": {
+                        "shapes": [],
+                        "m_range": "None",
+                        "topk": self.topk
+                    }
+                },
+                "quantization_config": {
+                    "quant_method": self.quant_method,
+                    "weight_block_size": "None"
+                }
+            }
+            # Handle MoE shapes
+            for shape in self.moe_weight_shapes:
+                if len(shape) == 4:
+                    # Assumes the MoE shape is in [E, N1, N2, K] format
+                    moe_config = {
+                        "E": shape[0],
+                        "N1": shape[1],
+                        "N2": shape[2],
+                        "K": shape[3],  # default values
+                    }
+                    config["layers"]["moe"]["shapes"].append(moe_config)
+            for shape in self.weight_shapes:
+                config["layers"]["linear"]["shapes"].append(shape)
+            if block_size is not None:
+                config["quantization_config"]["weight_block_size"] = block_size
+            with open(json_dir + "/model.json", 'w') as f:
+                json.dump(config, f, indent=4)
+        # else:
+        #     logger.info("LMSLIM_TUNING_JSON is not set")
+
+    def getspec_config(self, configs_dict, M, N, K):
+        if f"{M}_{N}_{K}" in configs_dict:
+            return configs_dict[f"{M}_{N}_{K}"]
+        else:
+            return None
+
+    def get_triton_cache(self, file_path, n, k):
+        # Used when not tuning; returns None directly if the file does not exist
+        cache_json_file = file_path
+        if os.path.exists(file_path):
+            #try:
+            with open(cache_json_file, 'r') as file:
+                cachedata = json.load(file)
+        else:
+            return None
+        # Flatten the whole cache into key:config form: [M_N_K]: [config]
+        configs_dict = {}
+        for key, value in cachedata.items():
+            for sub_key, sub_value in value.items():
+                configs_key = f"{sub_key}_{key}"
+                configs_dict[configs_key] = sub_value
+        return configs_dict
+
+    def get_w8a8json_name(self, n, k):
+        return self.triton_json_dir + f"/W8A8_{n}_{k}_{self.device_name}.json"
+
+    def get_blockint8_triton_cache(self, file_path, n, k, block_n, block_k):
+        cache_json_file = file_path
+        if os.path.exists(file_path):
+            #try:
+            with open(cache_json_file, 'r') as file:
+                cachedata = json.load(file)
+        else:
+            return None
+        # Flatten the whole cache into key:config form: [M_N_K]: [config]
+        configs_dict = {}
+        for key, value in cachedata.items():
+            for sub_key, sub_value in value.items():
+                configs_key = f"{sub_key}_{key}"
+                configs_dict[configs_key] = sub_value
+        return configs_dict
+
+    def get_blockint8json_name(self, n, k, block_n, block_k):
+        return self.triton_json_dir + f"/linear_{n}_{k}_block[{block_n},{block_k}]_{self.device_name}.json"
+
+    def get_moeint8json_name(self, E, N1, N2, K, TOPK,
+                             block_size: Optional[list] = None,
+                             use_int4_w4a8: Optional[bool] = False):
+        if use_int4_w4a8:
+            if block_size is not None:
+                return self.triton_json_dir + f"/MOE_W4A8INT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
+            else:
+                return self.triton_json_dir + f"/MOE_W4A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
+        else:
+            if block_size is not None:
+                return self.triton_json_dir + f"/MOE_BLOCKINT8[{block_size[0]},{block_size[1]}]_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
+            else:
+                return self.triton_json_dir + f"/MOE_W8A8INT8_E={E}_N1={N1}_N2={N2}_K={K}_TOPK{TOPK}_{self.device_name}.json"
+
+    def get_moeint8_triton_cache(self, file_path, E, N1, N2, K, TOPK):
+        cache_json_file = file_path
+        if os.path.exists(file_path):
+            #try:
+            with open(cache_json_file, 'r') as file:
+                cachedata = json.load(file)
+        else:
+            return None
+        # Flatten the whole cache into key:config form: [M_N_K]: [config1, config2]
+        configs_dict = {}
+        for key, value in cachedata.items():
+            for sub_key, sub_value in value.items():
+                configs_key = f"{sub_key}_{key}"
+                configs_dict[configs_key] = sub_value
+        return configs_dict
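The flattening loop in get_triton_cache (repeated in the block-int8 and MoE variants) turns the two-level JSON written by the tuner into flat "M_N_K" keys, which is the form apply_weights looks up. A worked sketch follows; the file contents and config fields are hypothetical, and the two-level layout {"<n>_<k>": {"<m>": config}} is inferred from the "{sub_key}_{key}" key construction together with the "{m}_{n}_{k}" lookups:

    cachedata = {
        "11008_4096": {                               # key     = "<n>_<k>"
            "16":  {"BLOCK_M": 16, "num_warps": 4},   # sub_key = "<m>" (fields hypothetical)
            "256": {"BLOCK_M": 64, "num_warps": 8},
        }
    }

    configs_dict = {}
    for key, value in cachedata.items():
        for sub_key, sub_value in value.items():
            configs_dict[f"{sub_key}_{key}"] = sub_value  # "<m>_<n>_<k>"

    assert "16_11008_4096" in configs_dict   # M_N_K, as looked up by apply_weights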