Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5b7f2c7d
Commit
5b7f2c7d
authored
Jan 21, 2026
by
SAC_fanth
Browse files
fused_moe_fp8接入lmslim
parent
c2e7f4ad
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
21 additions
and
417 deletions
+21
-417
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+3
-3
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+4
-408
vllm/utils/__init__.py
vllm/utils/__init__.py
+14
-6
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
5b7f2c7d
...
...
@@ -1791,7 +1791,7 @@ def fused_experts_impl(
device
=
hidden_states
.
device
,
dtype
=
hidden_states
.
dtype
)
if
use_int8_w8a8
is
True
:
if
use_int8_w8a8
or
use_fp8_w8a8
:
return
fused_experts_impl_int8
(
hidden_states
=
hidden_states
,
w1
=
w1
,
w2
=
w2
,
...
...
@@ -1801,8 +1801,8 @@ def fused_experts_impl(
inplace
=
inplace
,
activation
=
activation
,
apply_router_weight_on_input
=
apply_router_weight_on_input
,
use_fp8_w8a8
=
False
,
use_int8_w8a8
=
True
,
use_fp8_w8a8
=
use_fp8_w8a8
,
use_int8_w8a8
=
use_int8_w8a8
,
use_int8_w8a16
=
False
,
use_int4_w4a16
=
False
,
per_channel_quant
=
per_channel_quant
,
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
5b7f2c7d
...
...
@@ -21,6 +21,10 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
tl
,
triton
from
vllm.utils
import
cdiv
,
direct_register_custom_op
,
has_deep_gemm
try
:
from
lmslim.layers.gemm.fp8_utils
import
per_token_group_quant_fp8
,
w8a8_block_fp8_matmul
except
Exception
:
print
(
"INFO: Please updata lmslim if you want to use fp8_utils.
\n
"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -256,330 +260,6 @@ def block_quant_to_tensor_quant(
return
x_q_tensor
,
scale
@
triton
.
jit
def
_per_token_group_quant_fp8
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
y_s_ptr
,
group_size
,
# Num columns of y
y_num_columns
,
y_row_stride
,
# Avoid to divide zero
eps
,
# Information for float8
fp8_min
,
fp8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row
=
y_num_columns
//
group_size
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
row
=
g_id
//
groups_per_row
row_g_id
=
g_id
%
groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset
=
(
row
.
to
(
tl
.
int64
)
*
y_row_stride
)
+
(
row_g_id
.
to
(
tl
.
int64
)
*
group_size
)
y_ptr
+=
y_ptr_offset
y_q_ptr_offset
=
g_id
.
to
(
tl
.
int64
)
*
group_size
y_q_ptr
+=
y_q_ptr_offset
y_s_ptr
+=
g_id
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
mask
=
cols
<
group_size
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
@
triton
.
jit
def
_per_token_group_quant_fp8_colmajor
(
# Pointers to inputs and output
y_ptr
,
y_q_ptr
,
y_s_ptr
,
group_size
,
# Num columns of y
y_num_columns
,
y_row_stride
,
# Stride from one column to the next of y_s
y_s_col_stride
,
# Avoid to divide zero
eps
,
# Information for float8
fp8_min
,
fp8_max
,
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
):
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
This function converts the tensor values into float8 values.
"""
groups_per_row
=
y_num_columns
//
group_size
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
row
=
g_id
//
groups_per_row
row_g_id
=
g_id
%
groups_per_row
# Ensure offset calculations use int64 to prevent overflow
y_ptr_offset
=
(
row
.
to
(
tl
.
int64
)
*
y_row_stride
)
+
(
row_g_id
.
to
(
tl
.
int64
)
*
group_size
)
y_ptr
+=
y_ptr_offset
y_q_ptr_offset
=
g_id
.
to
(
tl
.
int64
)
*
group_size
y_q_ptr
+=
y_q_ptr_offset
# Convert g_id the flattened block coordinate to 2D so we can index
# into the output y_scales matrix
blocks_per_row
=
y_num_columns
//
group_size
scale_col
=
g_id
%
blocks_per_row
scale_row
=
g_id
//
blocks_per_row
# Ensure offset calculation uses int64 for y_s_ptr
y_s_ptr_offset
=
(
scale_col
.
to
(
tl
.
int64
)
*
y_s_col_stride
)
+
scale_row
.
to
(
tl
.
int64
)
y_s_ptr
+=
y_s_ptr_offset
cols
=
tl
.
arange
(
0
,
BLOCK
)
# group_size <= BLOCK
mask
=
cols
<
group_size
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
y_s
=
_absmax
/
fp8_max
y_q
=
tl
.
clamp
(
y
/
y_s
,
fp8_min
,
fp8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
def
per_token_group_quant_fp8
(
x
:
torch
.
Tensor
,
group_size
:
int
,
eps
:
float
=
1e-10
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
column_major_scales
:
bool
=
False
,
out_q
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Function to perform per-token-group quantization on an input tensor `x`.
It converts the tensor values into signed float8 values and returns the
quantized tensor along with the scaling factor used for quantization.
Args:
x: The input tensor with ndim >= 2.
group_size: The group size used for quantization.
eps: The minimum to avoid dividing zero.
dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
is supported for now.
column_major_scales: Outputs scales in column major.
out_q: Optional output tensor. If not provided, function will create.
Returns:
tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the
scaling factor for quantization.
"""
dtype
=
current_platform
.
fp8_dtype
()
if
dtype
is
None
else
dtype
assert
(
x
.
shape
[
-
1
]
%
group_size
==
0
),
(
f
"the last dimension of `x`
{
x
.
shape
[
-
1
]
}
must be divisible "
f
"by `group_size`
{
group_size
}
"
)
assert
x
.
stride
(
-
1
)
==
1
,
"`x` groups must be contiguous"
finfo
=
torch
.
finfo
(
dtype
)
fp8_min
=
finfo
.
min
fp8_max
=
finfo
.
max
assert
out_q
is
None
or
out_q
.
shape
==
x
.
shape
x_q
=
out_q
if
x_q
is
None
:
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
if
column_major_scales
:
shape
=
(
x
.
shape
[
-
1
]
//
group_size
,
)
+
x
.
shape
[:
-
1
]
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
).
permute
(
-
1
,
-
2
)
else
:
shape
=
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
)
x_s
=
torch
.
empty
(
shape
,
device
=
x
.
device
,
dtype
=
torch
.
float32
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
if
column_major_scales
:
_per_token_group_quant_fp8_colmajor
[(
M
,
)](
x
,
x_q
,
x_s
,
group_size
,
x
.
shape
[
1
],
x
.
stride
(
0
),
x_s
.
stride
(
1
),
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
else
:
_per_token_group_quant_fp8
[(
M
,
)](
x
,
x_q
,
x_s
,
group_size
,
x
.
shape
[
1
],
x
.
stride
(
0
),
eps
,
fp8_min
=
fp8_min
,
fp8_max
=
fp8_max
,
BLOCK
=
BLOCK
,
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
return
x_q
,
x_s
@
triton
.
jit
def
_w8a8_block_fp8_matmul
(
# Pointers to inputs and output
A
,
B
,
C
,
As
,
Bs
,
# Shape for matmul
M
,
N
,
K
,
# Block size for block-wise quantization
group_n
,
group_k
,
# Stride for inputs and output
stride_am
,
stride_ak
,
stride_bk
,
stride_bn
,
stride_cm
,
stride_cn
,
stride_As_m
,
stride_As_k
,
stride_Bs_k
,
stride_Bs_n
,
# Meta-parameters
BLOCK_SIZE_M
:
tl
.
constexpr
,
BLOCK_SIZE_N
:
tl
.
constexpr
,
BLOCK_SIZE_K
:
tl
.
constexpr
,
GROUP_SIZE_M
:
tl
.
constexpr
,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B` with block-wise quantization, and
store the result in output tensor `C`.
"""
pid
=
tl
.
program_id
(
axis
=
0
)
num_pid_m
=
tl
.
cdiv
(
M
,
BLOCK_SIZE_M
)
num_pid_n
=
tl
.
cdiv
(
N
,
BLOCK_SIZE_N
)
num_pid_in_group
=
GROUP_SIZE_M
*
num_pid_n
group_id
=
pid
//
num_pid_in_group
first_pid_m
=
group_id
*
GROUP_SIZE_M
group_size_m
=
min
(
num_pid_m
-
first_pid_m
,
GROUP_SIZE_M
)
pid_m
=
first_pid_m
+
(
pid
%
group_size_m
)
pid_n
=
(
pid
%
num_pid_in_group
)
//
group_size_m
offs_am
=
(
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
))
%
M
offs_bn
=
(
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
))
%
N
offs_k
=
tl
.
arange
(
0
,
BLOCK_SIZE_K
)
a_ptrs
=
A
+
(
offs_am
[:,
None
]
*
stride_am
+
offs_k
[
None
,
:]
*
stride_ak
)
b_ptrs
=
B
+
(
offs_k
[:,
None
]
*
stride_bk
+
offs_bn
[
None
,
:]
*
stride_bn
)
As_ptrs
=
As
+
offs_am
*
stride_As_m
offs_bsn
=
offs_bn
//
group_n
Bs_ptrs
=
Bs
+
offs_bsn
*
stride_Bs_n
accumulator
=
tl
.
zeros
((
BLOCK_SIZE_M
,
BLOCK_SIZE_N
),
dtype
=
tl
.
float32
)
for
k
in
range
(
0
,
tl
.
cdiv
(
K
,
BLOCK_SIZE_K
)):
a
=
tl
.
load
(
a_ptrs
,
mask
=
offs_k
[
None
,
:]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
b
=
tl
.
load
(
b_ptrs
,
mask
=
offs_k
[:,
None
]
<
K
-
k
*
BLOCK_SIZE_K
,
other
=
0.0
)
k_start
=
k
*
BLOCK_SIZE_K
offs_ks
=
k_start
//
group_k
a_s
=
tl
.
load
(
As_ptrs
+
offs_ks
*
stride_As_k
)
b_s
=
tl
.
load
(
Bs_ptrs
+
offs_ks
*
stride_Bs_k
)
accumulator
+=
tl
.
dot
(
a
,
b
)
*
a_s
[:,
None
]
*
b_s
[
None
,
:]
a_ptrs
+=
BLOCK_SIZE_K
*
stride_ak
b_ptrs
+=
BLOCK_SIZE_K
*
stride_bk
if
C
.
dtype
.
element_ty
==
tl
.
bfloat16
:
c
=
accumulator
.
to
(
tl
.
bfloat16
)
elif
C
.
dtype
.
element_ty
==
tl
.
float16
:
c
=
accumulator
.
to
(
tl
.
float16
)
else
:
c
=
accumulator
.
to
(
tl
.
float32
)
offs_cm
=
pid_m
*
BLOCK_SIZE_M
+
tl
.
arange
(
0
,
BLOCK_SIZE_M
)
offs_cn
=
pid_n
*
BLOCK_SIZE_N
+
tl
.
arange
(
0
,
BLOCK_SIZE_N
)
c_ptrs
=
C
+
stride_cm
*
offs_cm
[:,
None
]
+
stride_cn
*
offs_cn
[
None
,
:]
c_mask
=
(
offs_cm
[:,
None
]
<
M
)
&
(
offs_cn
[
None
,
:]
<
N
)
tl
.
store
(
c_ptrs
,
c
,
mask
=
c_mask
)
@
functools
.
lru_cache
def
get_w8a8_block_fp8_configs
(
N
:
int
,
K
:
int
,
block_n
:
int
,
block_k
:
int
)
->
Optional
[
dict
[
int
,
Any
]]:
"""
Return optimized configurations for the w8a8 block fp8 kernel.
The return value will be a dictionary that maps an irregular grid of
batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the
kernel on a given batch size bs, the closest batch size in the grid should
be picked and the associated configuration chosen to invoke the kernel.
"""
# First look up if an optimized configuration is available in the configs
# directory
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
json_file_name
=
f
"N=
{
N
}
,K=
{
K
}
,device_name=
{
device_name
}
,dtype=fp8_w8a8,block_shape=[
{
block_n
}
,
{
block_k
}
].json"
# noqa: E501
config_file_path
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
"Using configuration from %s for W8A8 Block FP8 kernel."
,
config_file_path
,
)
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
# If no optimized configuration is available, we will use the default
# configuration
logger
.
warning
(
"Using default W8A8 Block FP8 kernel config. Performance might "
"be sub-optimal! Config file not found at %s"
,
config_file_path
,
)
return
None
def
hipblaslt_w8a8_block_fp8_matmul
(
...
...
@@ -605,88 +285,4 @@ def hipblaslt_w8a8_block_fp8_matmul(
enum_block_size
,
None
)
return
d
def
w8a8_block_fp8_matmul
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
As
:
torch
.
Tensor
,
Bs
:
torch
.
Tensor
,
block_size
:
list
[
int
],
output_dtype
:
torch
.
dtype
=
torch
.
float16
,
)
->
torch
.
Tensor
:
"""This function performs matrix multiplication with block-wise
quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
Args:
A: The input tensor, e.g., activation.
B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. It should
be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor.
Returns:
torch.Tensor: The result of matmul.
"""
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
and
A
.
is_contiguous
()
assert
triton
.
cdiv
(
A
.
shape
[
-
1
],
block_k
)
==
As
.
shape
[
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
Bs
.
ndim
==
2
N
,
K
=
B
.
shape
assert
triton
.
cdiv
(
N
,
block_n
)
==
Bs
.
shape
[
0
]
assert
triton
.
cdiv
(
K
,
block_k
)
==
Bs
.
shape
[
1
]
C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,
)
C
=
A
.
new_empty
(
C_shape
,
dtype
=
output_dtype
)
configs
=
get_w8a8_block_fp8_configs
(
N
,
K
,
block_size
[
0
],
block_size
[
1
])
if
configs
:
# Get the optimal config if there is one
config
=
configs
[
min
(
configs
.
keys
(),
key
=
lambda
x
:
abs
(
x
-
M
))]
else
:
# Default config
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0]
# BLOCK_SIZE_K must be divisible by block_size[1]
config
=
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
block_size
[
0
],
"BLOCK_SIZE_K"
:
block_size
[
1
],
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
2
,
}
def
grid
(
META
):
return
(
triton
.
cdiv
(
M
,
META
[
"BLOCK_SIZE_M"
])
*
triton
.
cdiv
(
N
,
META
[
"BLOCK_SIZE_N"
]),
)
_w8a8_block_fp8_matmul
[
grid
](
A
,
B
,
C
,
As
,
Bs
,
M
,
N
,
K
,
block_n
,
block_k
,
A
.
stride
(
-
2
),
A
.
stride
(
-
1
),
B
.
stride
(
1
),
B
.
stride
(
0
),
C
.
stride
(
-
2
),
C
.
stride
(
-
1
),
As
.
stride
(
-
2
),
As
.
stride
(
-
1
),
Bs
.
stride
(
1
),
Bs
.
stride
(
0
),
**
config
,
)
return
C
vllm/utils/__init__.py
View file @
5b7f2c7d
...
...
@@ -1958,7 +1958,7 @@ class W8a8GetCacheJSON:
self
.
moe_weight_shapes
=
[]
arch_name
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
.
split
(
':'
)[
0
]
arch_cu
=
torch
.
cuda
.
get_device_properties
(
torch
.
cuda
.
current_device
()).
multi_processor_count
self
.
cache_json_data
=
{}
device_name
=
arch_name
+
'_'
+
str
(
arch_cu
)
+
'cu'
self
.
device_name
=
device_name
self
.
topk
=
1
...
...
@@ -2060,19 +2060,27 @@ class W8a8GetCacheJSON:
return
self
.
triton_json_dir
+
f
"/linear_
{
n
}
_
{
k
}
_block[
{
block_n
}
,
{
block_k
}
]_
{
self
.
device_name
}
.json"
def
get_moeint8json_name
(
self
,
E
,
N1
,
N2
,
K
,
TOPK
,
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
):
block_size
:
Optional
[
list
]
=
None
,
use_int4_w4a8
:
Optional
[
bool
]
=
False
,
use_int8_w8a8
:
Optional
[
bool
]
=
False
):
if
use_int4_w4a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W4A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
elif
use_int8_w8a8
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCKINT8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8INT8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
if
block_size
is
not
None
:
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
INT
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_BLOCK
FP
8[
{
block_size
[
0
]
}
,
{
block_size
[
1
]
}
]_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
else
:
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
INT
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
return
self
.
triton_json_dir
+
f
"/MOE_W8A8
FP
8_E=
{
E
}
_N1=
{
N1
}
_N2=
{
N2
}
_K=
{
K
}
_TOPK
{
TOPK
}
_
{
self
.
device_name
}
.json"
def
get_moeint8_triton_cache
(
self
,
file_path
,
E
,
N1
,
N2
,
K
,
TOPK
):
if
file_path
in
self
.
cache_json_data
:
# 直接返回缓存数据,避免重复读取
return
self
.
cache_json_data
[
file_path
]
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
...
...
@@ -2088,7 +2096,7 @@ class W8a8GetCacheJSON:
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_dict
[
configs_key
]
=
sub_value
self
.
cache_json_data
[
file_path
]
=
configs_dict
return
configs_dict
# Adapted from: https://stackoverflow.com/a/47212782/5082708
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment