Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fca01b8
"csrc/vscode:/vscode.git/clone" did not exist on "5d8e93f67a1bf5f96213ffe7e7f64633a8c0e8ea"
Commit
4fca01b8
authored
Apr 18, 2026
by
wangmin6
Committed by
zhangzbb
Apr 18, 2026
Browse files
[Perf]优化EP低延迟模式下调度,消除调度空泡
parent
8c96d505
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
205 additions
and
25 deletions
+205
-25
vllm/forward_context.py
vllm/forward_context.py
+2
-1
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
.../model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+169
-6
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
...l_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
+8
-10
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+1
-1
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+13
-0
vllm/v1/worker/dp_utils.py
vllm/v1/worker/dp_utils.py
+2
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+10
-6
No files found.
vllm/forward_context.py
View file @
4fca01b8
...
@@ -340,7 +340,8 @@ def set_forward_context(
...
@@ -340,7 +340,8 @@ def set_forward_context(
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
dp_metadata
:
DPMetadata
|
None
=
None
dp_metadata
:
DPMetadata
|
None
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
(
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
and
\
envs
.
VLLM_ALL2ALL_BACKEND
!=
"deepep_low_latency"
and
(
attn_metadata
is
not
None
or
num_tokens
is
not
None
attn_metadata
is
not
None
or
num_tokens
is
not
None
):
):
# If num_tokens_across_dp hasn't already been initialized, then
# If num_tokens_across_dp hasn't already been initialized, then
...
...
vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
View file @
4fca01b8
...
@@ -3,6 +3,9 @@
...
@@ -3,6 +3,9 @@
import
torch
import
torch
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
import
vllm.model_executor.layers.fused_moe.modular_kernel
as
mk
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
from
vllm.forward_context
import
get_forward_context
,
is_forward_context_available
...
@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
...
@@ -33,8 +36,10 @@ from vllm.utils.deep_gemm import (
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.math_utils
import
cdiv
,
round_up
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.utils.import_utils
import
has_deep_gemm
from
vllm.model_executor.layers.activation
import
SiluAndMul
from
lightop
import
fuse_silu_mul_quant_ep
from
lightop
import
fuse_silu_mul_quant_ep
from
lmslim.layers.gemm.int8_utils
import
per_token_quant_int8
if
has_deep_gemm
():
if
has_deep_gemm
():
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
from
deepgemm
import
m_grouped_w8a8_gemm_nt_masked
else
:
else
:
...
@@ -45,6 +50,161 @@ else:
...
@@ -45,6 +50,161 @@ else:
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
# ==============================================
# MoE grouped GEMM Triton kernel (int8 quantization + expert parallelism)
# Input layout after All2All: A is [E, M, K], B is [E, N, K]
# Output: [E, M, N], written directly into the caller-provided output tensor
# ==============================================
@triton.jit
def moe_grouped_gemm_kernel(
    # Pointers
    A_ptr,
    B_ptr,
    A_scale_ptr,
    B_scale_ptr,
    token_counts_ptr,
    output_ptr,
    # Strides (expert/E stride, token/M stride, out-channel/N stride,
    # feature/K stride)
    stride_A_E,
    stride_A_M,
    stride_A_K,
    stride_B_E,
    stride_B_N,
    stride_B_K,
    stride_A_scale_E,
    stride_A_scale_M,
    stride_B_scale_E,
    stride_B_scale_N,
    stride_out_E,
    stride_out_M,
    stride_out_N,
    # Fixed dimensions
    E: tl.constexpr,  # total number of experts (not read by the kernel body)
    M: tl.constexpr,  # maximum number of tokens per expert
    N: tl.constexpr,  # output dimension per expert
    K: tl.constexpr,  # input feature dimension
    # Tile sizes
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Compute one [BLOCK_M, BLOCK_N] output tile of a per-expert GEMM.

    For expert ``e`` the kernel evaluates
    ``out[e] = (A[e] @ B[e]^T) * A_scale[e][:, None] * B_scale[e][None, :]``
    restricted to the first ``token_counts[e]`` rows.

    Launch grid: axis 0 selects the expert, axis 1 the token tile and
    axis 2 the output-channel tile.

    Every load is masked along M/N as well as K, so edge tiles never read
    outside the tensors when M or N is not a multiple of the tile size.
    Rows at or beyond ``token_counts[e]`` are neither loaded nor stored.
    """
    # ===================== 1. Expert id and tile coordinates =================
    pid_e = tl.program_id(0)  # expert index
    pid_m = tl.program_id(1)  # token-tile index
    pid_n = tl.program_id(2)  # output-channel-tile index

    # Number of tokens this expert actually has to process.
    token_cnt = tl.load(token_counts_ptr + pid_e)

    # Tiles that start past the valid tokens have nothing to do.
    if pid_m * BLOCK_M >= token_cnt:
        return

    # ===================== 2. Base pointers and tile offsets =================
    A_base = A_ptr + pid_e * stride_A_E  # A: [E, M, K]
    B_base = B_ptr + pid_e * stride_B_E  # B: [E, N, K]
    A_scale_base = A_scale_ptr + pid_e * stride_A_scale_E
    B_scale_base = B_scale_ptr + pid_e * stride_B_scale_E
    out_base = output_ptr + pid_e * stride_out_E  # out: [E, M, N]

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Valid rows are bounded by both the dynamic token count and the static
    # buffer extent M; valid columns by N. These masks guard loads and the
    # final store alike.
    mask_m = (offs_m < token_cnt) & (offs_m < M)
    mask_n = offs_n < N

    a_ptrs = A_base + (offs_m[:, None] * stride_A_M + offs_k[None, :] * stride_A_K)
    b_ptrs = B_base + (offs_n[:, None] * stride_B_N + offs_k[None, :] * stride_B_K)

    # ===================== 3. Accumulator ====================================
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # ===================== 4. Reduction over K ===============================
    for k in range(0, K, BLOCK_K):
        # Columns of this K-slice that are still inside the tensor.
        mask_k = offs_k < K - k
        a = tl.load(a_ptrs, mask=mask_m[:, None] & mask_k[None, :], other=0)
        b = tl.load(b_ptrs, mask=mask_n[:, None] & mask_k[None, :], other=0)
        # B tile is [BLOCK_N, BLOCK_K]; transpose to [BLOCK_K, BLOCK_N].
        # NOTE(review): for int8 operands tl.dot presumably returns int32,
        # which is then promoted into the float32 accumulator - confirm that
        # this precision is acceptable for large K.
        acc += tl.dot(a, tl.trans(b))
        a_ptrs += BLOCK_K * stride_A_K
        b_ptrs += BLOCK_K * stride_B_K

    # ===================== 5. Dequantize (per-token x per-channel) ===========
    a_scale = tl.load(
        A_scale_base + offs_m * stride_A_scale_M, mask=mask_m, other=0.0
    )  # [BLOCK_M]
    b_scale = tl.load(
        B_scale_base + offs_n * stride_B_scale_N, mask=mask_n, other=0.0
    )  # [BLOCK_N]
    # out = (A @ B^T) * A_scale * B_scale
    result = acc * a_scale[:, None] * b_scale[None, :]

    # ===================== 6. Store into out[E, M, N] ========================
    out_ptrs = out_base + (
        offs_m[:, None] * stride_out_M + offs_n[None, :] * stride_out_N
    )
    # Only valid tokens and valid output channels are written.
    mask = mask_m[:, None] & mask_n[None, :]
    tl.store(out_ptrs, result, mask=mask)
# ==============================================
# Host-side wrapper (public entry point; derives strides and the launch grid)
# ==============================================
def moe_grouped_gemm(
    A: torch.Tensor,  # [E, M, K]
    B: torch.Tensor,  # [E, N, K] int8
    A_scale: torch.Tensor,  # [E, M, 1]
    B_scale: torch.Tensor,  # [E, N, 1]
    token_counts: torch.Tensor,  # [E]
    output: torch.Tensor,  # [E, M, N] (caller-provided, written in place)
):
    """Launch ``moe_grouped_gemm_kernel`` over all experts and return ``output``.

    Args:
        A: Activations, shape ``[E, M, K]``.
        B: Expert weights, shape ``[E, N, K]``.
        A_scale: Per-token scales, shape ``[E, M, 1]``.
        B_scale: Per-output-channel scales, shape ``[E, N, 1]``.
        token_counts: Valid token count per expert, shape ``[E]``.
        output: Destination tensor, shape ``[E, M, N]``; filled by the kernel.

    Returns:
        The same ``output`` tensor that was passed in.

    Raises:
        AssertionError: If shapes are inconsistent, the tensors are not all
            on one device, or ``A`` is not a CUDA tensor.
    """
    # Shape validation.
    num_experts, max_tokens, in_features = A.shape
    _, out_features, _ = B.shape
    assert B.shape == (num_experts, out_features, in_features)
    assert A_scale.shape == (num_experts, max_tokens, 1)
    assert B_scale.shape == (num_experts, out_features, 1)
    assert token_counts.shape == (num_experts,)
    assert output.shape == (num_experts, max_tokens, out_features)

    # All operands must live on one CUDA device.
    assert all(
        t.device == A.device
        for t in (B, A_scale, B_scale, token_counts, output)
    )
    assert A.is_cuda

    # Fixed tile sizes.
    block_m, block_n, block_k = 64, 64, 64

    # Grid: [E, ceil(M / BLOCK_M), ceil(N / BLOCK_N)].
    grid = (
        num_experts,
        triton.cdiv(max_tokens, block_m),
        triton.cdiv(out_features, block_n),
    )

    # Strides are taken from the tensors themselves.
    a_e, a_m, a_k = A.stride()
    b_e, b_n, b_k = B.stride()
    out_e, out_m, out_n = output.stride()
    stride_kwargs = {
        "stride_A_E": a_e,
        "stride_A_M": a_m,
        "stride_A_K": a_k,
        "stride_B_E": b_e,
        "stride_B_N": b_n,
        "stride_B_K": b_k,
        "stride_A_scale_E": A_scale.stride(0),
        "stride_A_scale_M": A_scale.stride(1),
        "stride_B_scale_E": B_scale.stride(0),
        "stride_B_scale_N": B_scale.stride(1),
        "stride_out_E": out_e,
        "stride_out_M": out_m,
        "stride_out_N": out_n,
    }

    # Launch the kernel.
    moe_grouped_gemm_kernel[grid](
        A,
        B,
        A_scale,
        B_scale,
        token_counts,
        output,
        **stride_kwargs,
        # Fixed dimensions
        E=num_experts,
        M=max_tokens,
        N=out_features,
        K=in_features,
        # Tile sizes
        BLOCK_M=block_m,
        BLOCK_N=block_n,
        BLOCK_K=block_k,
    )
    return output
def
scales_shape_stride_dtype
(
def
scales_shape_stride_dtype
(
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
E
:
int
,
T
:
int
,
G
:
int
,
quant_scale_fmt
:
DeepGemmQuantScaleFMT
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
)
->
tuple
[
tuple
[
int
,
...],
tuple
[
int
,
...],
torch
.
dtype
]:
...
@@ -297,6 +457,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -297,6 +457,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
self
.
N
=
N
self
.
N
=
N
self
.
K
=
K
self
.
K
=
K
self
.
act_fn
=
SiluAndMul
()
@
staticmethod
@
staticmethod
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
def
activation_format
()
->
mk
.
FusedMoEActivationFormat
:
...
@@ -414,7 +575,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -414,7 +575,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace2
:
torch
.
Tensor
,
workspace2
:
torch
.
Tensor
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
expert_tokens_meta
:
mk
.
ExpertTokensMetadata
|
None
,
apply_router_weight_on_input
:
bool
,
apply_router_weight_on_input
:
bool
,
use_nn_moe
:
bool
|
None
=
False
,
**
_
):
):
assert
expert_tokens_meta
is
not
None
assert
expert_tokens_meta
is
not
None
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
expert_num_tokens
=
expert_tokens_meta
.
expert_num_tokens
...
@@ -436,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
...
@@ -436,11 +597,13 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
workspace1
=
_resize_cache
(
workspace13
,
(
E
,
max_num_tokens
,
N
))
expected_m
=
self
.
estimate_expected_m
(
# expected_m = self.estimate_expected_m(
global_num_experts
=
global_num_experts
,
# global_num_experts=global_num_experts,
max_tokens_per_expert
=
max_num_tokens
,
# max_tokens_per_expert=max_num_tokens,
topk
=
topk_ids
.
size
(
-
1
),
# topk=topk_ids.size(-1),
)
# )
expected_m
=
self
.
get_expected_m
()
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
if
self
.
quant_config
.
use_fp8_w8a16
or
self
.
quant_config
.
use_fp8_w8a8
:
fp8_m_grouped_gemm_nt_masked
(
fp8_m_grouped_gemm_nt_masked
(
...
...
vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
View file @
4fca01b8
...
@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
...
@@ -297,21 +297,19 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# Dispatch
# Dispatch
dispatch_topk_ids
=
self
.
_map_global_to_physical_ids
(
topk_ids
)
dispatch_topk_ids
=
self
.
_map_global_to_physical_ids
(
topk_ids
)
quant_type
=
0
if
self
.
use_int8_dispatch
:
quant_type
=
1
elif
self
.
use_fp8_dispatch
:
quant_type
=
2
expert_x
,
expert_num_tokens
,
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
expert_x
,
expert_num_tokens
,
handle
,
_
,
hook
=
self
.
buffer
.
low_latency_dispatch
(
a1
,
a1
,
dispatch_topk_ids
,
dispatch_topk_ids
,
self
.
max_tokens_per_rank
,
self
.
max_tokens_per_rank
,
num_experts
,
num_experts
,
use_fp8
=
self
.
use_fp8_dispatch
or
self
.
use_int8_dispatch
,
quant_type
=
quant_type
,
use_int8
=
self
.
use_int8_dispatch
,
fp8_round_scale
=
False
,
round_scale
=
self
.
use_ue8m0_dispatch
,
use_ue8m0
=
self
.
use_ue8m0_dispatch
,
**
(
dict
(
use_nvfp4
=
True
)
if
use_nvfp4
else
dict
()),
**
(
dict
(
x_global_scale
=
qc_a1_gscale_or_scale
)
if
qc_a1_gscale_or_scale
is
not
None
else
dict
()
),
async_finish
=
False
,
async_finish
=
False
,
return_recv_hook
=
True
,
return_recv_hook
=
True
,
)
)
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
4fca01b8
...
@@ -853,7 +853,7 @@ class FusedMoE(CustomOp):
...
@@ -853,7 +853,7 @@ class FusedMoE(CustomOp):
def
use_dp_chunking
(
self
)
->
bool
:
def
use_dp_chunking
(
self
)
->
bool
:
return
(
return
(
self
.
moe_parallel_config
.
use_pplx_kernels
self
.
moe_parallel_config
.
use_pplx_kernels
or
self
.
moe_parallel_config
.
use_deepep_ll_kernels
#
or self.moe_parallel_config.use_deepep_ll_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
self
.
moe_parallel_config
.
use_mori_kernels
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
or
(
self
.
dp_size
>
1
and
self
.
use_flashinfer_cutlass_kernels
)
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
)
and
envs
.
VLLM_ENABLE_MOE_DP_CHUNK
...
...
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
4fca01b8
...
@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -406,6 +406,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
self
.
max_num_tokens
=
max_num_tokens
self
.
max_num_tokens
=
max_num_tokens
self
.
num_dispatchers
=
num_dispatchers
self
.
num_dispatchers
=
num_dispatchers
self
.
expected_m
=
max_num_tokens
@
staticmethod
@
staticmethod
def
expects_unquantized_inputs
(
def
expects_unquantized_inputs
(
...
@@ -774,6 +775,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
...
@@ -774,6 +775,12 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
chooses to do weight application.
chooses to do weight application.
"""
"""
raise
NotImplementedError
raise
NotImplementedError
def set_expected_m(self, expected_m):
    """Store ``expected_m`` on this instance for later retrieval."""
    self.expected_m = expected_m
def get_expected_m(self):
    """Return the ``expected_m`` value currently stored on this instance."""
    return self.expected_m
def
_slice_scales
(
def
_slice_scales
(
...
@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1074,6 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
that handles DBO and async.
"""
"""
expected_m
=
(
hidden_states
.
shape
[
0
]
*
self
.
fused_experts
.
num_dispatchers
*
topk_ids
.
shape
[
1
]
+
global_num_experts
)
//
global_num_experts
self
.
fused_experts
.
set_expected_m
(
expected_m
)
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# We shouldn't be running an a2a kernel that doesn't
# support async prepare/finalize
# support async prepare/finalize
...
...
vllm/v1/worker/dp_utils.py
View file @
4fca01b8
...
@@ -6,6 +6,7 @@ import numpy as np
...
@@ -6,6 +6,7 @@ import numpy as np
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.config
import
ParallelConfig
from
vllm.config
import
ParallelConfig
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.distributed.parallel_state
import
get_dp_group
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
...
@@ -208,7 +209,7 @@ def coordinate_batch_across_dp(
]
]
"""
"""
if
parallel_config
.
data_parallel_size
==
1
:
if
parallel_config
.
data_parallel_size
==
1
or
envs
.
VLLM_ALL2ALL_BACKEND
==
"deepep_low_latency"
:
# Early exit.
# Early exit.
return
False
,
None
,
cudagraph_mode
return
False
,
None
,
cudagraph_mode
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
4fca01b8
...
@@ -183,6 +183,7 @@ from .utils import (
...
@@ -183,6 +183,7 @@ from .utils import (
sanity_check_mm_encoder_outputs
,
sanity_check_mm_encoder_outputs
,
)
)
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.v1.spec_decode.utils
import
DraftProbs
from
vllm.utils.torch_utils
import
async_tensor_h2d
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
from
vllm.model_executor.model_loader.tensorizer
import
TensorizerConfig
...
@@ -4789,9 +4790,6 @@ class GPUModelRunner(
...
@@ -4789,9 +4790,6 @@ class GPUModelRunner(
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
self
.
inputs_embeds
.
gpu
[:
num_tokens_padded
]
model_kwargs
=
self
.
_init_model_kwargs
()
model_kwargs
=
self
.
_init_model_kwargs
()
else
:
else
:
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
=
torch
.
randint
(
0
,
self
.
model_config
.
get_vocab_size
(),
(
num_tokens_padded
,),
dtype
=
torch
.
int32
)
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
input_ids
=
self
.
input_ids
.
gpu
[:
num_tokens_padded
]
inputs_embeds
=
None
inputs_embeds
=
None
...
@@ -4904,9 +4902,15 @@ class GPUModelRunner(
...
@@ -4904,9 +4902,15 @@ class GPUModelRunner(
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
self
.
eplb_step
(
is_dummy
=
True
,
is_profile
=
is_profile
)
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices
=
np
.
cumsum
(
num_scheduled_tokens
)
-
1
logit_indices_device
=
torch
.
from_numpy
(
logit_indices
).
to
(
# logit_indices_device = torch.from_numpy(logit_indices).to(
self
.
device
,
non_blocking
=
True
# self.device, non_blocking=True
)
# )
logit_indices
=
logit_indices
.
tolist
()
logit_indices_device
=
async_tensor_h2d
(
logit_indices
,
dtype
=
torch
.
int32
,
target_device
=
self
.
device
,
pin_memory
=
True
)
return
hidden_states
,
hidden_states
[
logit_indices_device
]
return
hidden_states
,
hidden_states
[
logit_indices_device
]
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment