Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
7771c0ac
Commit
7771c0ac
authored
Feb 11, 2026
by
SAC_fanth
Browse files
接入channel、block triton 及channel-wise marlin
parent
9fdb8e3a
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
257 additions
and
32 deletions
+257
-32
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-3
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
...quantization/compressed_tensors/compressed_tensors_moe.py
+16
-10
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
...ation/compressed_tensors/compressed_tensors_moe_marlin.py
+204
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
...compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
+3
-2
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+13
-8
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
...executor/layers/quantization/kernels/scaled_mm/pytorch.py
+1
-0
vllm/utils/__init__.py
vllm/utils/__init__.py
+17
-8
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
7771c0ac
...
...
@@ -1725,7 +1725,7 @@ def fused_experts_impl(
else
:
cache13
=
torch
.
empty
(
M
*
top_k_num
*
max
(
N
,
K
if
not
use_nn_moe
else
w2
.
shape
[
2
]),
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
...
...
@@ -1735,8 +1735,8 @@ def fused_experts_impl(
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
View file @
7771c0ac
...
...
@@ -1109,22 +1109,28 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
use_nn_moe
:
bool
|
None
=
False
,
use_fused_gate
:
bool
|
None
=
False
,
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
not
self
.
is_monolithic
assert
self
.
kernel
is
not
None
return
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
# TODO(rob): investigate the disable_expert_map introduced by:
# https://github.com/vllm-project/vllm/commit/84166fee9770e6fba71a96978b3e7d149392fb28 # noqa: E501
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
use_fused_gate
=
use_fused_gate
,
use_nn_moe
=
False
,
)
@
property
...
...
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe_marlin.py
View file @
7771c0ac
...
...
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
)
try
:
from
lmslim.layers.fused_moe.fuse_moe_int8_marlin
import
fused_experts_impl_int8_marlin
from
lmslim.layers.fused_moe.fuse_moe_fp8_marlin
import
fused_experts_impl_fp8_marlin
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer the quantitative model of moe.
\n
"
)
...
...
@@ -35,9 +36,32 @@ logger = init_logger(__name__)
__all__
=
[
"CompressedTensorsW8A8Int8MarlinMoEMethod"
,
"CompressedTensorsW8A8FP8MarlinMoEMethod"
,
]
def
fp32_to_fp8_e4m3fn
(
t
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""更合理的FP32到Float8_e4m3fn转换,使用最近值而不是简单舍弃尾数"""
# torch.float8_e4m3fn的数值范围约[-448, 448]
fp8_min
,
fp8_max
=
-
448.0
,
448.0
t_clamped
=
t
.
clamp
(
min
=
fp8_min
,
max
=
fp8_max
)
# 保证不会下溢到0
# 转换前到float16再转fp8可能提升精度(float8实现本身通常通过float16做rounding)
t_fp16
=
t_clamped
.
to
(
torch
.
float16
)
return
t_fp16
.
to
(
torch
.
float8_e4m3fn
)
def
w8a8_fp8_nt_kpack2_marlin_weight
(
w8a8_w
,
# [size_n, size_k// 2 ]
k_tile
=
16
,
n_tile
=
16
,
):
size_n
,
size_k
=
w8a8_w
.
shape
assert
size_n
%
k_tile
==
0
and
size_k
%
n_tile
==
0
,
"k_tile / n_tile 必须能整除对应维度"
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
n_tile
,
n_tile
,
size_k
//
k_tile
,
k_tile
))
w8a8_w
=
w8a8_w
.
permute
((
0
,
2
,
1
,
3
)).
contiguous
()
w8a8_w
=
w8a8_w
.
reshape
((
size_n
//
k_tile
,
size_k
*
k_tile
))
return
w8a8_w
class
CompressedTensorsMarlinMoEMethod
(
FusedMoEMethodBase
):
def
__init_
(
self
,
moe
:
FusedMoEConfig
):
super
().
__init__
(
moe
)
...
...
@@ -52,12 +76,191 @@ class CompressedTensorsMarlinMoEMethod(FusedMoEMethodBase):
weight_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
input_quant
=
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
if
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
if
quant_config
.
_is_fp8_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8FP8MarlinMoEMethod
(
quant_config
,
layer
.
moe_config
)
elif
quant_config
.
_is_dynamic_token_w8a8
(
weight_quant
,
input_quant
):
return
CompressedTensorsW8A8Int8MarlinMoEMethod
(
quant_config
,
layer
.
moe_config
)
else
:
raise
RuntimeError
(
f
"Slimquant_marlin does not support the FusedMoe scheme:
{
weight_quant
}
,
{
input_quant
}
"
)
class
CompressedTensorsW8A8FP8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
quant_config
:
"CompressedTensorsMarlinConfig"
,
# type: ignore # noqa E501
moe
:
FusedMoEConfig
):
self
.
quant_config
=
quant_config
super
().
__init__
(
moe
)
self
.
weight_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"weights"
)
self
.
input_quant
=
self
.
quant_config
.
target_scheme_map
[
"Linear"
].
get
(
"input_activations"
)
per_channel
=
(
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
and
self
.
input_quant
.
strategy
==
QuantizationStrategy
.
TOKEN
)
if
not
per_channel
:
raise
ValueError
(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f
"
{
self
.
weight_quant
}
,
{
self
.
input_quant
}
"
)
self
.
static_input_scales
=
not
self
.
input_quant
.
dynamic
if
self
.
static_input_scales
:
raise
ValueError
(
"For FP8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales."
)
self
.
fused_experts
=
self
.
fused_moe_forward
def
get_fused_moe_quant_config
(
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
FusedMoEQuantConfig
]:
return
None
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
num_experts
:
int
,
hidden_size
:
int
,
intermediate_size_per_partition
:
int
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
):
params_dtype
=
torch
.
float8_e4m3fn
# WEIGHTS
w13_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size_per_partition
,
hidden_size
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
w2_weight
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
intermediate_size_per_partition
,
dtype
=
params_dtype
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
# WEIGHT_SCALES
assert
self
.
weight_quant
.
strategy
==
QuantizationStrategy
.
CHANNEL
w13_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
2
*
intermediate_size_per_partition
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w13_weight_scale"
,
w13_weight_scale
)
w2_weight_scale
=
torch
.
nn
.
Parameter
(
torch
.
ones
(
num_experts
,
hidden_size
,
1
,
dtype
=
torch
.
float32
),
requires_grad
=
False
)
layer
.
register_parameter
(
"w2_weight_scale"
,
w2_weight_scale
)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs
.
update
(
{
"quant_method"
:
FusedMoeWeightScaleSupported
.
CHANNEL
.
value
})
set_weight_attrs
(
w13_weight_scale
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight_scale
,
extra_weight_attrs
)
# INPUT_SCALES
assert
not
self
.
static_input_scales
layer
.
w13_input_scale
=
None
layer
.
w2_input_scale
=
None
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
w1_marlin_list
=
[]
for
ii
in
range
(
layer
.
w13_weight
.
shape
[
0
]):
w1_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w13_weight
[
ii
])
w1_marlin_list
.
append
(
w1_marlin_in
.
float
()
if
w1_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w1_marlin_in
)
w1_marlin
=
torch
.
stack
(
w1_marlin_list
,
dim
=
0
)
w1_marlin
=
fp32_to_fp8_e4m3fn
(
w1_marlin
)
del
w1_marlin_list
w2_marlin_list
=
[]
for
ii
in
range
(
layer
.
w2_weight
.
shape
[
0
]):
w2_marlin_in
=
get_w8a8_int8_marlin_weights
(
layer
.
w2_weight
[
ii
])
w2_marlin_list
.
append
(
w2_marlin_in
.
float
()
if
w2_marlin_in
.
dtype
==
torch
.
float8_e4m3fn
else
w2_marlin_in
)
w2_marlin
=
torch
.
stack
(
w2_marlin_list
,
dim
=
0
)
w2_marlin
=
fp32_to_fp8_e4m3fn
(
w2_marlin
)
layer
.
w13_weight
=
Parameter
(
w1_marlin
,
requires_grad
=
False
)
layer
.
w2_weight
=
Parameter
(
w2_marlin
,
requires_grad
=
False
)
def
fused_moe_forward
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
):
return
fused_experts_impl_fp8_marlin
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
True
,
per_channel_quant
=
True
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
layer
.
w13_weight_scale
,
w2_scale
=
layer
.
w2_weight_scale
,
a1_scale
=
layer
.
w13_input_scale
,
a2_scale
=
layer
.
w2_input_scale
,
use_nn_moe
=
False
,
shared_output
=
shared_output
,
routed_scaling_factor
=
routed_scaling_factor
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
custom_routing_function
:
Optional
[
Callable
]
=
None
,
scoring_func
:
str
=
"softmax"
,
e_score_correction_bias
:
Optional
[
torch
.
Tensor
]
=
None
,
apply_router_weight_on_input
:
bool
=
False
,
activation
:
str
=
"silu"
,
enable_eplb
:
bool
=
False
,
use_nn_moe
:
Optional
[
bool
]
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
use_fused_gate
:
Optional
[
bool
]
=
False
,
expert_load_view
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_to_physical_map
:
Optional
[
torch
.
Tensor
]
=
None
,
logical_replica_count
:
Optional
[
torch
.
Tensor
]
=
None
,
shared_output
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
if
enable_eplb
:
raise
NotImplementedError
(
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet."
)
return
self
.
fused_experts
(
layer
=
layer
,
x
=
x
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
activation
=
activation
,
routed_scaling_factor
=
routed_scaling_factor
,
shared_output
=
shared_output
,
)
class
CompressedTensorsW8A8Int8MarlinMoEMethod
(
CompressedTensorsMarlinMoEMethod
):
def
__init__
(
self
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
View file @
7771c0ac
...
...
@@ -161,8 +161,9 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
)
if
envs
.
VLLM_W8A8_BACKEND
==
3
:
weight
=
weight
.
t
().
contiguous
()
else
:
weight
=
weight
.
t
()
# triton不用转置,torch需要
# else:
# weight = weight.t()
elif
self
.
strategy
==
QuantizationStrategy
.
BLOCK
:
assert
self
.
is_static_input_scheme
is
False
weight
,
weight_scale
=
process_fp8_weight_block_strategy
(
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
7771c0ac
...
...
@@ -1031,17 +1031,22 @@ class Fp8MoEMethod(FusedMoEMethodBase):
)
->
torch
.
Tensor
|
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
kernel
is
not
None
assert
not
self
.
is_monolithic
return
self
.
kernel
(
x
,
layer
.
w13_weight
,
layer
.
w2_weight
,
topk_weights
,
topk_ids
,
inplace
=
self
.
use_inplace
,
from
vllm.model_executor.layers.fused_moe
import
fused_experts
return
fused_experts
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_weights
=
topk_weights
,
topk_ids
=
topk_ids
,
inplace
=
True
,
activation
=
layer
.
activation
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
global_num_experts
=
layer
.
global_num_experts
,
expert_map
=
layer
.
expert_map
,
apply_router_weight_on_input
=
layer
.
apply_router_weight_on_input
,
quant_config
=
self
.
moe_quant_config
,
use_fused_gate
=
use_fused_gate
,
use_nn_moe
=
False
,
)
...
...
vllm/model_executor/layers/quantization/kernels/scaled_mm/pytorch.py
View file @
7771c0ac
...
...
@@ -14,6 +14,7 @@ from .ScaledMMLinearKernel import (
)
try
:
from
lmslim.quantize.quant_ops
import
hipblaslt_w8a8_channelwise_gemm
from
lmslim.layers.gemm.fp8_utils
import
triton_scaled_mm_fp8
except
ImportError
:
print
(
"INFO: Please updata lmslim if you want to use fp8_utils.
\n
"
)
from
vllm
import
envs
...
...
vllm/utils/__init__.py
View file @
7771c0ac
...
...
@@ -61,7 +61,7 @@ class W8a8GetCacheJSON:
self
.
moe_weight_shapes
=
[]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
self
.
cache_json_data
=
{}
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
self
.
device_name
=
device_name
self
.
topk
=
1
...
...
@@ -162,21 +162,30 @@ class W8a8GetCacheJSON:
def
get_blockint8json_name
(
self
,
n
,
k
,
block_n
,
block_k
):
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
list
|
None
=
None
,
use_int4_w4a8
:
bool
|
None
=
False
):
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
list
|
None
=
None
,
use_int4_w4a8
:
bool
|
None
=
False
,
use_int8_w8a8
:
bool
|
None
=
False
):
if
use_int4_w4a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
elif
use_int8_w8a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W
4
A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W
8
A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
INT
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
FP
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
INT
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
FP
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
if
file_path
in
self
.
cache_json_data
:
# 直接返回缓存数据,避免重复读取
return
self
.
cache_json_data
[
file_path
]
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
...
...
@@ -192,7 +201,7 @@ class W8a8GetCacheJSON:
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_dict
[
configs_key
]
=
sub_value
self
.
cache_json_data
[
file_path
]
=
configs_dict
return
configs_dict
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment