Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
587c919c
Commit
587c919c
authored
May 06, 2026
by
王敏
Browse files
[Feat]初步实现PP+MTP
parent
ca9ce18d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
210 additions
and
32 deletions
+210
-32
vllm/model_executor/layers/fused_moe/modular_kernel.py
vllm/model_executor/layers/fused_moe/modular_kernel.py
+6
-5
vllm/model_executor/models/deepseek_mtp.py
vllm/model_executor/models/deepseek_mtp.py
+10
-0
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+16
-12
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+178
-15
No files found.
vllm/model_executor/layers/fused_moe/modular_kernel.py
View file @
587c919c
...
@@ -1081,11 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
...
@@ -1081,11 +1081,12 @@ class FusedMoEModularKernel(torch.nn.Module):
The _prepare method is a wrapper around self.prepare_finalize.prepare
The _prepare method is a wrapper around self.prepare_finalize.prepare
that handles DBO and async.
that handles DBO and async.
"""
"""
expected_m
=
(
if
self
.
fused_experts
.
num_dispatchers
is
not
None
:
hidden_states
.
shape
[
0
]
*
self
.
fused_experts
.
num_dispatchers
*
topk_ids
.
shape
[
1
]
expected_m
=
(
+
global_num_experts
hidden_states
.
shape
[
0
]
*
self
.
fused_experts
.
num_dispatchers
*
topk_ids
.
shape
[
1
]
)
//
global_num_experts
+
global_num_experts
self
.
fused_experts
.
set_expected_m
(
expected_m
)
)
//
global_num_experts
self
.
fused_experts
.
set_expected_m
(
expected_m
)
if
not
self
.
prepare_finalize
.
supports_async
():
if
not
self
.
prepare_finalize
.
supports_async
():
# We shouldn't be running an a2a kernel that doesn't
# We shouldn't be running an a2a kernel that doesn't
...
...
vllm/model_executor/models/deepseek_mtp.py
View file @
587c919c
...
@@ -341,6 +341,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
...
@@ -341,6 +341,16 @@ class DeepSeekMTP(nn.Module, DeepseekV2MixtureOfExperts, SupportsPP):
model_has_indexer
=
any
(
"indexer"
in
param_name
for
param_name
in
params_dict
.
keys
())
model_has_indexer
=
any
(
"indexer"
in
param_name
for
param_name
in
params_dict
.
keys
())
for
name
,
loaded_weight
in
weights
:
for
name
,
loaded_weight
in
weights
:
if
"embed_tokens"
in
name
:
for
local_name
in
params_dict
.
keys
():
if
"embed_tokens"
in
local_name
:
param
=
params_dict
[
local_name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
weight_loader
(
param
,
loaded_weight
)
break
if
"rotary_emb.inv_freq"
in
name
:
if
"rotary_emb.inv_freq"
in
name
:
continue
continue
...
...
vllm/v1/core/sched/scheduler.py
View file @
587c919c
...
@@ -355,14 +355,14 @@ class Scheduler(SchedulerInterface):
...
@@ -355,14 +355,14 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
# TODO: support PP + async scheduling without this limit
if
self
.
use_pp
:
#
if self.use_pp:
if
(
envs
.
VLLM_USE_PP_BALANCE
and
#
if (envs.VLLM_USE_PP_BALANCE and
len
(
scheduled_new_reqs
)
+
len
(
scheduled_resumed_reqs
)
#
len(scheduled_new_reqs) + len(scheduled_resumed_reqs)
+
len
(
scheduled_running_reqs
)
>=
max_batch_running
):
#
+ len(scheduled_running_reqs) >= max_batch_running):
break
#
break
if
request
.
num_output_placeholders
>
0
:
#
if request.num_output_placeholders > 0:
req_index
+=
1
#
req_index += 1
continue
#
continue
if
(
if
(
request
.
num_output_placeholders
>
0
request
.
num_output_placeholders
>
0
...
@@ -1211,9 +1211,9 @@ class Scheduler(SchedulerInterface):
...
@@ -1211,9 +1211,9 @@ class Scheduler(SchedulerInterface):
# do not schedule another step for the same request while it still has
# do not schedule another step for the same request while it still has
# output placeholders for PP.
# output placeholders for PP.
# TODO: support PP + async scheduling without this limit
# TODO: support PP + async scheduling without this limit
if
self
.
use_pp
and
request
.
num_output_placeholders
>
0
:
#
if self.use_pp and request.num_output_placeholders > 0:
req_index
+=
1
#
req_index += 1
continue
#
continue
if
(
if
(
request
.
num_output_placeholders
>
0
request
.
num_output_placeholders
>
0
...
@@ -1617,7 +1617,11 @@ class Scheduler(SchedulerInterface):
...
@@ -1617,7 +1617,11 @@ class Scheduler(SchedulerInterface):
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
for
idx
,
req
in
enumerate
(
itertools
.
chain
(
running_reqs
,
resumed_reqs
)):
req_id
=
req
.
request_id
req_id
=
req
.
request_id
req_ids
.
append
(
req_id
)
req_ids
.
append
(
req_id
)
if
self
.
use_pp
:
#if self.use_pp:
# NOTE: In PP+async scheduling, we consume token ids via a direct GPU
# broadcast path (`input_batch.prev_sampled_token_ids`), so we can
# omit this payload.
if
self
.
use_pp
and
not
self
.
scheduler_config
.
async_scheduling
:
# When using PP, the scheduler sends the sampled tokens back,
# When using PP, the scheduler sends the sampled tokens back,
# because there's no direct communication between the first-
# because there's no direct communication between the first-
# stage worker and the last-stage worker. Otherwise, we don't
# stage worker and the last-stage worker. Otherwise, we don't
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
587c919c
...
@@ -966,12 +966,15 @@ class GPUModelRunner(
...
@@ -966,12 +966,15 @@ class GPUModelRunner(
# that case we include the resumed_req_ids in the unscheduled set so
# that case we include the resumed_req_ids in the unscheduled set so
# that they get cleared from the persistent batch before being re-scheduled
# that they get cleared from the persistent batch before being re-scheduled
# in the normal resumed request path.
# in the normal resumed request path.
#print(f"##################cached_req_ids:{cached_req_ids} scheduled_req_ids:{scheduled_req_ids} async scheduling:{self.use_async_scheduling}")
unscheduled_req_ids
=
cached_req_ids
-
(
scheduled_req_ids
-
resumed_req_ids
)
unscheduled_req_ids
=
cached_req_ids
-
(
scheduled_req_ids
-
resumed_req_ids
)
# NOTE(woosuk): The persistent batch optimization assumes that
# NOTE(woosuk): The persistent batch optimization assumes that
# consecutive batches contain mostly the same requests. If batches
# consecutive batches contain mostly the same requests. If batches
# have low request overlap (e.g., alternating between two distinct
# have low request overlap (e.g., alternating between two distinct
# sets of requests), this optimization becomes very inefficient.
# sets of requests), this optimization becomes very inefficient.
for
req_id
in
unscheduled_req_ids
:
for
req_id
in
unscheduled_req_ids
:
#print(f"#############################remove_request:{req_id}")
self
.
input_batch
.
remove_request
(
req_id
)
self
.
input_batch
.
remove_request
(
req_id
)
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
reqs_to_add
:
list
[
CachedRequestState
]
=
[]
...
@@ -1077,25 +1080,32 @@ class GPUModelRunner(
...
@@ -1077,25 +1080,32 @@ class GPUModelRunner(
num_rejected
=
req_state
.
prev_num_draft_len
-
num_accepted
num_rejected
=
req_state
.
prev_num_draft_len
-
num_accepted
num_computed_tokens
-=
num_rejected
num_computed_tokens
-=
num_rejected
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
req_state
.
output_token_ids
.
extend
([
-
1
]
*
num_accepted
)
#print(f"#############################req_id:{req_id} num_accepted:{num_accepted}")
# Update the cached states.
# Update the cached states.
req_state
.
num_computed_tokens
=
num_computed_tokens
req_state
.
num_computed_tokens
=
num_computed_tokens
if
not
is_last_rank
:
if
not
is_last_rank
:
# When using PP, the scheduler sends the sampled tokens back,
if
not
req_data
.
new_token_ids
:
# because there's no direct communication between the first-
# Async scheduled PP: Sampled tokens propagated via GPU broadcast.
# stage worker and the last-stage worker.
new_token_ids
:
list
[
int
]
=
[]
new_token_ids
=
req_data
.
new_token_ids
[
i
]
else
:
# Add the sampled token(s) from the previous step (if any).
# Non-async scheduling with PP: The scheduler sends
# This doesn't include "unverified" tokens like spec tokens.
# sampled token ids back because there's no direct communication
num_new_tokens
=
(
# between the first-stage worker and the last-stage worker.
num_computed_tokens
+
len
(
new_token_ids
)
-
req_state
.
num_tokens
new_token_ids
=
req_data
.
new_token_ids
[
i
]
)
# Add the sampled token(s) from the previous step (if any).
if
num_new_tokens
==
1
:
# This doesn't include "unverified" tokens like spec tokens.
# Avoid slicing list in most common case.
num_new_tokens
=
(
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
num_computed_tokens
+
len
(
new_token_ids
)
-
req_state
.
num_tokens
elif
num_new_tokens
>
0
:
)
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:])
if
num_new_tokens
==
1
:
# Avoid slicing list in most common case.
req_state
.
output_token_ids
.
append
(
new_token_ids
[
-
1
])
elif
num_new_tokens
>
0
:
req_state
.
output_token_ids
.
extend
(
new_token_ids
[
-
num_new_tokens
:]
)
elif
num_output_tokens
<
len
(
req_state
.
output_token_ids
):
elif
num_output_tokens
<
len
(
req_state
.
output_token_ids
):
# Some output tokens were discarded due to a sync-KV-load
# Some output tokens were discarded due to a sync-KV-load
# failure. Align the cached state.
# failure. Align the cached state.
...
@@ -1431,6 +1441,8 @@ class GPUModelRunner(
...
@@ -1431,6 +1441,8 @@ class GPUModelRunner(
prev_common_req_indices_tensor
=
torch
.
tensor
(
prev_common_req_indices_tensor
=
torch
.
tensor
(
prev_common_req_indices
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
prev_common_req_indices
,
dtype
=
torch
.
int64
,
pin_memory
=
self
.
pin_memory
).
to
(
self
.
device
,
non_blocking
=
True
)
).
to
(
self
.
device
,
non_blocking
=
True
)
#print(f"###############sampled_tokens_index_tensor:{sampled_tokens_index_tensor} prev_common_req_indices_tensor:{prev_common_req_indices_tensor} prev_sampled_token_ids:{self.input_batch.prev_sampled_token_ids}")
self
.
input_ids
.
gpu
.
scatter_
(
self
.
input_ids
.
gpu
.
scatter_
(
dim
=
0
,
dim
=
0
,
index
=
sampled_tokens_index_tensor
,
index
=
sampled_tokens_index_tensor
,
...
@@ -3938,6 +3950,7 @@ class GPUModelRunner(
...
@@ -3938,6 +3950,7 @@ class GPUModelRunner(
scheduler_output
,
clear_metadata
=
clear_kv_metadata
scheduler_output
,
clear_metadata
=
clear_kv_metadata
)
as
kv_connector_output
,
)
as
kv_connector_output
,
):
):
#print(f"####################execute model input_ids:{input_ids.tolist()}")
model_output
=
self
.
_model_forward
(
model_output
=
self
.
_model_forward
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
@@ -4028,7 +4041,17 @@ class GPUModelRunner(
...
@@ -4028,7 +4041,17 @@ class GPUModelRunner(
self
.
kv_connector_output
=
None
self
.
kv_connector_output
=
None
if
self
.
execute_model_state
is
None
:
if
self
.
execute_model_state
is
None
:
# Nothing to do (PP non-final rank case), output isn't used.
# receive sampled token ids from the last PP rank.
if
self
.
use_async_scheduling
and
get_pp_group
().
world_size
>
1
:
if
not
self
.
speculative_config
:
self
.
_pp_receive_prev_sampled_token_ids_to_input_batch
()
else
:
self
.
_draft_token_ids
=
None
self
.
_draft_token_req_ids
=
None
self
.
input_batch
.
prev_sampled_token_ids
=
None
self
.
_pp_receive_prev_sampled_token_ids_and_valid_sampled_tokens_count
()
self
.
_pp_receive_draft_token_ids
()
if
not
kv_connector_output
:
if
not
kv_connector_output
:
return
None
# type: ignore[return-value]
return
None
# type: ignore[return-value]
...
@@ -4070,6 +4093,12 @@ class GPUModelRunner(
...
@@ -4070,6 +4093,12 @@ class GPUModelRunner(
sampler_output
.
sampled_token_ids
,
scheduler_output
sampler_output
.
sampled_token_ids
,
scheduler_output
)
)
pp
=
get_pp_group
()
if
self
.
use_async_scheduling
and
pp
.
world_size
>
1
and
pp
.
is_last_rank
and
not
self
.
speculative_config
:
self
.
_pp_broadcast_prev_sampled_token_ids
(
sampler_output
.
sampled_token_ids
)
self
.
_draft_token_ids
=
None
self
.
_draft_token_ids
=
None
self
.
_draft_token_req_ids
=
None
self
.
_draft_token_req_ids
=
None
self
.
input_batch
.
prev_sampled_token_ids
=
None
self
.
input_batch
.
prev_sampled_token_ids
=
None
...
@@ -4090,6 +4119,11 @@ class GPUModelRunner(
...
@@ -4090,6 +4119,11 @@ class GPUModelRunner(
)
)
self
.
_copy_draft_token_ids_to_cpu
(
scheduler_output
)
self
.
_copy_draft_token_ids_to_cpu
(
scheduler_output
)
# broadcast draft_token_ids to non-last pp rank
if
self
.
use_async_scheduling
and
pp
.
world_size
>
1
and
pp
.
is_last_rank
:
self
.
_pp_broadcast_draft_token_ids
(
self
.
_draft_token_ids
)
spec_config
=
self
.
speculative_config
spec_config
=
self
.
speculative_config
propose_drafts_after_bookkeeping
=
False
propose_drafts_after_bookkeeping
=
False
if
spec_config
is
not
None
:
if
spec_config
is
not
None
:
...
@@ -4207,6 +4241,129 @@ class GPUModelRunner(
...
@@ -4207,6 +4241,129 @@ class GPUModelRunner(
return
async_output
return
async_output
def
_pp_broadcast_prev_sampled_token_ids
(
self
,
sampled_token_ids
:
torch
.
Tensor
)
->
None
:
"""Broadcast sampled token ids (GPU) from last PP stage"""
pp
=
get_pp_group
()
assert
pp
.
is_last_rank
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
assert
sampled_token_ids
.
dim
()
==
2
and
sampled_token_ids
.
shape
[
-
1
]
==
1
,
(
"PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
)
torch
.
distributed
.
broadcast
(
sampled_token_ids
,
src
=
pp
.
rank
,
group
=
pp
.
device_group
)
def
_pp_receive_prev_sampled_token_ids_to_input_batch
(
self
)
->
None
:
"""Receive sampled token ids broadcast from last PP stage"""
pp
=
get_pp_group
()
assert
not
pp
.
is_last_rank
num_reqs
=
self
.
input_batch
.
num_reqs
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
recv
=
torch
.
empty
((
num_reqs
,
1
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
recv
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
)
self
.
input_batch
.
prev_sampled_token_ids
=
recv
# construct `prev_req_id_to_index` here so `_prepare_input_ids`
# can map req_id -> previous batch row
discard_req_indices
=
np
.
nonzero
(
self
.
discard_request_mask
.
np
[:
num_reqs
])[
0
]
discard_req_indices_set
=
set
(
discard_req_indices
)
prev_req_id_to_index
:
dict
[
str
,
int
]
=
{}
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
if
i
in
discard_req_indices_set
:
continue
prev_req_id_to_index
[
req_id
]
=
i
# PP+async scheduling: advance per-request local cached output length by
# appending a placeholder (-1) token id.
if
(
req_state
:
=
self
.
requests
.
get
(
req_id
))
is
not
None
:
req_state
.
output_token_ids
.
append
(
-
1
)
self
.
input_batch
.
prev_req_id_to_index
=
prev_req_id_to_index
def
_pp_broadcast_prev_sampled_token_ids_and_valid_sampled_tokens_count
(
self
,
sampled_token_ids
:
torch
.
Tensor
,
valid_sampled_tokens_count
:
torch
.
Tensor
)
->
None
:
"""Broadcast sampled token ids (GPU) from last PP stage"""
pp
=
get_pp_group
()
assert
pp
.
is_last_rank
sampled_token_ids
=
sampled_token_ids
.
view
(
-
1
,
1
)
valid_sampled_tokens_count
=
valid_sampled_tokens_count
.
view
(
-
1
,
1
)
#print(f"##################pp broadcast sampled_token_ids:{sampled_token_ids.tolist()} valid_sampled_tokens_count:{valid_sampled_tokens_count.tolist()}")
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
assert
sampled_token_ids
.
dim
()
==
2
and
sampled_token_ids
.
shape
[
-
1
]
==
1
,
(
"PP+async expects sampled_token_ids to have shape [num_reqs, 1]"
)
assert
valid_sampled_tokens_count
.
dim
()
==
2
and
sampled_token_ids
.
shape
[
-
1
]
==
1
,
(
"PP+async expects valid_sampled_tokens_count to have shape [num_reqs, 1]"
)
data
=
torch
.
cat
([
sampled_token_ids
,
valid_sampled_tokens_count
],
dim
=-
1
)
torch
.
distributed
.
broadcast
(
data
,
src
=
pp
.
rank
,
group
=
pp
.
device_group
)
def
_pp_receive_prev_sampled_token_ids_and_valid_sampled_tokens_count
(
self
)
->
None
:
"""Receive sampled token ids broadcast from last PP stage"""
pp
=
get_pp_group
()
assert
not
pp
.
is_last_rank
num_reqs
=
self
.
input_batch
.
num_reqs
# `prev_sampled_token_ids` is expected to have shape [num_reqs, 1].
recv
=
torch
.
empty
((
num_reqs
,
2
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
recv
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
)
prev_sampled_token_ids
=
recv
[:,
:
1
].
squeeze
(
1
)
valid_sampled_tokens_count
=
recv
[:,
-
1
]
self
.
_copy_valid_sampled_token_count
(
prev_sampled_token_ids
,
valid_sampled_tokens_count
)
#print(f"#############pp recv prev_sampled_token_ids:{prev_sampled_token_ids.tolist()} valid_sampled_tokens_count:{valid_sampled_tokens_count.tolist()}")
# construct `prev_req_id_to_index` here so `_prepare_input_ids`
# can map req_id -> previous batch row
discard_req_indices
=
np
.
nonzero
(
self
.
discard_request_mask
.
np
[:
num_reqs
])[
0
]
discard_req_indices_set
=
set
(
discard_req_indices
)
prev_req_id_to_index
:
dict
[
str
,
int
]
=
{}
for
i
,
req_id
in
enumerate
(
self
.
input_batch
.
req_ids
):
if
i
in
discard_req_indices_set
:
continue
prev_req_id_to_index
[
req_id
]
=
i
# PP+async scheduling: advance per-request local cached output length by
# appending a placeholder (-1*(self.num_spec_tokens + 1)) token id.
# if (req_state := self.requests.get(req_id)) is not None:
# #req_state.output_token_ids.append(-1)
# req_state.output_token_ids.extend([-1] * (self.num_spec_tokens + 1))
self
.
input_batch
.
prev_req_id_to_index
=
prev_req_id_to_index
def
_pp_broadcast_draft_token_ids
(
self
,
draft_token_ids
:
torch
.
Tensor
)
->
None
:
"""Broadcast sampled token ids (GPU) from last PP stage"""
pp
=
get_pp_group
()
assert
pp
.
is_last_rank
draft_token_ids
=
draft_token_ids
.
to
(
torch
.
int32
)
# `draft_token_ids` is expected to have shape [num_reqs, num_spec_tokens].
assert
draft_token_ids
.
dim
()
==
2
and
draft_token_ids
.
shape
[
-
1
]
==
self
.
num_spec_tokens
,
(
"PP+async expects sampled_token_ids to have shape [num_reqs, num_spec_tokens]"
)
#print(f"####################broadcast draft_token_ids:{draft_token_ids}")
torch
.
distributed
.
broadcast
(
draft_token_ids
,
src
=
pp
.
rank
,
group
=
pp
.
device_group
)
def
_pp_receive_draft_token_ids
(
self
)
->
None
:
"""Receive sampled token ids broadcast from last PP stage"""
pp
=
get_pp_group
()
assert
not
pp
.
is_last_rank
num_reqs
=
self
.
input_batch
.
num_reqs
# `prev_sampled_token_ids` is expected to have shape [num_reqs, num_spec_tokens].
recv
=
torch
.
empty
((
num_reqs
,
self
.
num_spec_tokens
),
dtype
=
torch
.
int32
,
device
=
self
.
device
)
torch
.
distributed
.
broadcast
(
recv
,
src
=
pp
.
last_rank
,
group
=
pp
.
device_group
)
#print(f"####################pp recv draft_token_ids:{recv} num_spec_tokens:{self.num_spec_tokens}")
self
.
_draft_token_ids
=
recv
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
def
take_draft_token_ids
(
self
)
->
DraftTokenIds
|
None
:
if
not
self
.
num_spec_tokens
or
not
self
.
_draft_token_req_ids
:
if
not
self
.
num_spec_tokens
or
not
self
.
_draft_token_req_ids
:
return
None
return
None
...
@@ -4379,6 +4536,12 @@ class GPUModelRunner(
...
@@ -4379,6 +4536,12 @@ class GPUModelRunner(
self
.
discard_request_mask
.
gpu
,
self
.
discard_request_mask
.
gpu
,
)
)
)
)
# broadcast next_token_ids and valid_sampled_tokens_count to non-last pp rank
pp
=
get_pp_group
()
if
self
.
use_async_scheduling
and
pp
.
world_size
>
1
and
pp
.
is_last_rank
:
self
.
_pp_broadcast_prev_sampled_token_ids_and_valid_sampled_tokens_count
(
next_token_ids
,
valid_sampled_tokens_count
)
self
.
_copy_valid_sampled_token_count
(
self
.
_copy_valid_sampled_token_count
(
next_token_ids
,
valid_sampled_tokens_count
next_token_ids
,
valid_sampled_tokens_count
)
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment