Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
6266c57b
Unverified
Commit
6266c57b
authored
May 14, 2025
by
youkaichao
Committed by
GitHub
May 14, 2025
Browse files
[core][distributed] add ep group and all2all interface (#18077)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
754b699c
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
234 additions
and
41 deletions
+234
-41
vllm/distributed/device_communicators/all2all.py
vllm/distributed/device_communicators/all2all.py
+93
-0
vllm/distributed/device_communicators/base_device_communicator.py
...tributed/device_communicators/base_device_communicator.py
+25
-1
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+37
-2
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+59
-2
vllm/distributed/utils.py
vllm/distributed/utils.py
+7
-2
vllm/envs.py
vllm/envs.py
+5
-0
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/fused_moe/layer.py
+5
-33
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-1
No files found.
vllm/distributed/device_communicators/all2all.py
0 → 100644
View file @
6266c57b
# SPDX-License-Identifier: Apache-2.0
import
torch
from
vllm.forward_context
import
get_forward_context
class
All2AllBase
:
def
__init__
(
self
,
cpu_group
,
model
):
self
.
cpu_group
=
cpu_group
# compute some common properties
from
vllm.distributed.parallel_state
import
(
get_dp_group
,
get_ep_group
,
get_tp_group
,
in_the_same_node_as
)
# all2all lives in ep group, which is merged from dp and tp group
self
.
dp_group
=
get_dp_group
()
self
.
tp_group
=
get_tp_group
()
self
.
ep_group
=
get_ep_group
()
self
.
dp_rank
=
self
.
dp_group
.
rank_in_group
self
.
dp_world_size
=
self
.
dp_group
.
world_size
# all2all communication often has separate implementations for
# intra-node and inter-node communication
self
.
intranode
=
in_the_same_node_as
(
cpu_group
,
source_rank
=
0
)
self
.
internode
=
not
self
.
intranode
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
raise
NotImplementedError
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
def
destroy
(
self
):
pass
class
NaiveAll2All
(
All2AllBase
):
"""
A naive implementation of all2all communication.
It uses all-reduce under the hood, which is not
efficient at all. The main purpose is for testing and
debugging.
"""
def
__init__
(
self
,
cpu_group
,
model
):
super
().
__init__
(
cpu_group
,
model
)
def
naive_multicast
(
self
,
x
:
torch
.
Tensor
,
cu_tokens_across_dp_cpu
:
torch
.
Tensor
):
assert
(
len
(
x
.
shape
)
==
2
)
buffer
=
torch
.
empty
((
cu_tokens_across_dp_cpu
[
-
1
],
x
.
size
(
1
)),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
buffer
[
start
:
end
,
:].
copy_
(
x
)
for
idx
in
range
(
self
.
dp_world_size
):
start
=
0
if
idx
==
0
else
cu_tokens_across_dp_cpu
[
idx
-
1
]
end
=
cu_tokens_across_dp_cpu
[
idx
]
self
.
dp_group
.
broadcast
(
buffer
[
start
:
end
,
:],
idx
)
return
buffer
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
cu_tokens_across_dp_cpu
=
get_forward_context
(
).
dp_metadata
.
cu_tokens_across_dp_cpu
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_dp_cpu
)
router_logits
=
self
.
naive_multicast
(
router_logits
,
cu_tokens_across_dp_cpu
)
return
hidden_states
,
router_logits
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
cu_tokens_across_dp_cpu
=
get_forward_context
(
).
dp_metadata
.
cu_tokens_across_dp_cpu
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
all_hidden_states
=
self
.
dp_group
.
all_reduce
(
hidden_states
)
hidden_states
=
all_hidden_states
[
start
:
end
,
:]
return
hidden_states
def
destroy
(
self
):
pass
vllm/distributed/device_communicators/base_device_communicator.py
View file @
6266c57b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
...
@@ -149,3 +149,27 @@ class DeviceCommunicatorBase:
...
@@ -149,3 +149,27 @@ class DeviceCommunicatorBase:
def
destroy
(
self
):
def
destroy
(
self
):
pass
pass
def
prepare_communication_buffer_for_model
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
Prepare the communication buffer for the model.
This is a no-op in the base class.
"""
pass
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
Dispatch the hidden states and router logits to the appropriate device.
This is a no-op in the base class.
"""
return
hidden_states
,
router_logits
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
Combine the hidden states and router logits from the appropriate device.
This is a no-op in the base class.
"""
return
hidden_states
vllm/distributed/device_communicators/cuda_communicator.py
View file @
6266c57b
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
from
typing
import
Optional
from
typing
import
Optional
,
Tuple
import
torch
import
torch
from
torch.distributed
import
ProcessGroup
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
from
.all2all
import
All2AllBase
from
.base_device_communicator
import
DeviceCommunicatorBase
from
.base_device_communicator
import
DeviceCommunicatorBase
...
@@ -23,9 +26,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -23,9 +26,13 @@ class CudaCommunicator(DeviceCommunicatorBase):
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
_ENABLE_CUSTOM_ALL_REDUCE
)
_ENABLE_CUSTOM_ALL_REDUCE
)
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
use_pynccl
=
True
# ep does not use pynccl
use_pynccl
=
"ep"
not
in
unique_name
self
.
use_pynccl
=
use_pynccl
self
.
use_pynccl
=
use_pynccl
self
.
use_all2all
=
"ep"
in
unique_name
self
.
all2all_impl
:
Optional
[
All2AllBase
]
=
None
self
.
use_custom_allreduce
=
use_custom_allreduce
self
.
use_custom_allreduce
=
use_custom_allreduce
# lazy import to avoid documentation build error
# lazy import to avoid documentation build error
...
@@ -129,3 +136,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -129,3 +136,31 @@ class CudaCommunicator(DeviceCommunicatorBase):
self
.
pynccl_comm
=
None
self
.
pynccl_comm
=
None
if
self
.
ca_comm
is
not
None
:
if
self
.
ca_comm
is
not
None
:
self
.
ca_comm
=
None
self
.
ca_comm
=
None
if
self
.
all2all_impl
is
not
None
:
self
.
all2all_impl
.
destroy
()
self
.
all2all_impl
=
None
def
prepare_communication_buffer_for_model
(
self
,
model
:
torch
.
nn
.
Module
)
->
None
:
"""
Prepare the communication buffer for the model.
"""
if
not
self
.
use_all2all
:
return
all2all_backend
=
envs
.
VLLM_ALL2ALL_BACKEND
if
all2all_backend
==
"naive"
:
from
.all2all
import
NaiveAll2All
self
.
all2all_impl
=
NaiveAll2All
(
self
.
cpu_group
,
model
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
assert
self
.
all2all_impl
is
not
None
hidden_states
,
router_logits
=
self
.
all2all_impl
.
dispatch
(
hidden_states
,
router_logits
)
return
hidden_states
,
router_logits
def
combine
(
self
,
hidden_states
:
torch
.
Tensor
)
->
torch
.
Tensor
:
assert
self
.
all2all_impl
is
not
None
hidden_states
=
self
.
all2all_impl
.
combine
(
hidden_states
)
return
hidden_states
vllm/distributed/parallel_state.py
View file @
6266c57b
...
@@ -757,6 +757,22 @@ class GroupCoordinator:
...
@@ -757,6 +757,22 @@ class GroupCoordinator:
if
self
.
mq_broadcaster
is
not
None
:
if
self
.
mq_broadcaster
is
not
None
:
self
.
mq_broadcaster
=
None
self
.
mq_broadcaster
=
None
def
prepare_communication_buffer_for_model
(
self
,
model
:
torch
.
nn
.
Module
):
if
self
.
device_communicator
is
not
None
:
self
.
device_communicator
.
prepare_communication_buffer_for_model
(
model
)
def
dispatch
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
dispatch
(
hidden_states
,
router_logits
)
def
combine
(
self
,
hidden_states
)
->
torch
.
Tensor
:
if
self
.
device_communicator
is
not
None
:
return
self
.
device_communicator
.
combine
(
hidden_states
)
_WORLD
:
Optional
[
GroupCoordinator
]
=
None
_WORLD
:
Optional
[
GroupCoordinator
]
=
None
...
@@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator:
...
@@ -816,6 +832,14 @@ def get_dp_group() -> GroupCoordinator:
return
_DP
return
_DP
_EP
:
Optional
[
GroupCoordinator
]
=
None
def
get_ep_group
()
->
GroupCoordinator
:
assert
_EP
is
not
None
,
(
"expert parallel group is not initialized"
)
return
_EP
def
get_pp_group
()
->
GroupCoordinator
:
def
get_pp_group
()
->
GroupCoordinator
:
assert
_PP
is
not
None
,
(
assert
_PP
is
not
None
,
(
"pipeline model parallel group is not initialized"
)
"pipeline model parallel group is not initialized"
)
...
@@ -1001,10 +1025,21 @@ def initialize_model_parallel(
...
@@ -1001,10 +1025,21 @@ def initialize_model_parallel(
backend
,
backend
,
group_name
=
"dp"
)
group_name
=
"dp"
)
global
_EP
assert
_EP
is
None
,
(
"expert parallel group is already initialized"
)
group_ranks
=
all_ranks
.
transpose
(
1
,
2
).
reshape
(
-
1
,
data_parallel_size
*
tensor_model_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
_EP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
group_name
=
"ep"
)
logger
.
info
(
logger
.
info
(
"rank %s in world size %s is assigned as "
"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s"
,
rank
,
world_size
,
"DP rank %s, PP rank %s, TP rank %s, EP rank %s"
,
rank
,
world_size
,
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
)
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
,
_EP
.
rank_in_group
)
def
ensure_model_parallel_initialized
(
def
ensure_model_parallel_initialized
(
...
@@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized(
...
@@ -1035,6 +1070,23 @@ def ensure_model_parallel_initialized(
f
"
{
pipeline_model_parallel_size
=
}
"
)
f
"
{
pipeline_model_parallel_size
=
}
"
)
def
prepare_communication_buffer_for_model
(
model
:
torch
.
nn
.
Module
):
"""Prepare the communication buffer for the model.
Traditional communication libraries like NCCL are almost
model agnostic. However, emerging new communication libraries like
MoE all2all (DeepEP) usually allocate the communication buffer
based on the model shape for optimal performance.
"""
if
_TP
is
not
None
:
_TP
.
prepare_communication_buffer_for_model
(
model
)
if
_PP
is
not
None
:
_PP
.
prepare_communication_buffer_for_model
(
model
)
if
_DP
is
not
None
:
_DP
.
prepare_communication_buffer_for_model
(
model
)
if
_EP
is
not
None
:
_EP
.
prepare_communication_buffer_for_model
(
model
)
def
model_parallel_is_initialized
():
def
model_parallel_is_initialized
():
"""Check if tensor and pipeline parallel groups are initialized."""
"""Check if tensor and pipeline parallel groups are initialized."""
return
(
_TP
is
not
None
and
_PP
is
not
None
)
return
(
_TP
is
not
None
and
_PP
is
not
None
)
...
@@ -1095,6 +1147,11 @@ def destroy_model_parallel():
...
@@ -1095,6 +1147,11 @@ def destroy_model_parallel():
_DP
.
destroy
()
_DP
.
destroy
()
_DP
=
None
_DP
=
None
global
_EP
if
_EP
:
_EP
.
destroy
()
_EP
=
None
def
destroy_distributed_environment
():
def
destroy_distributed_environment
():
global
_WORLD
global
_WORLD
...
...
vllm/distributed/utils.py
View file @
6266c57b
...
@@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
...
@@ -362,6 +362,11 @@ def stateless_destroy_torch_distributed_process_group(
stateless_init_torch_distributed_process_group().
stateless_init_torch_distributed_process_group().
"""
"""
# Lazy import for non-CUDA backends.
# Lazy import for non-CUDA backends.
from
torch.distributed.distributed_c10d
import
_shutdown_backend
try
:
_shutdown_backend
(
pg
)
# pytorch <= 2.6
from
torch.distributed.distributed_c10d
import
_shutdown_backend
_shutdown_backend
(
pg
)
except
ImportError
:
# pytorch >= 2.7
pg
.
shutdown
()
_unregister_process_group
(
pg
.
group_name
)
_unregister_process_group
(
pg
.
group_name
)
vllm/envs.py
View file @
6266c57b
...
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
...
@@ -115,6 +115,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_INSECURE_SERIALIZATION
:
bool
=
False
VLLM_ALLOW_INSECURE_SERIALIZATION
:
bool
=
False
VLLM_NIXL_SIDE_CHANNEL_HOST
:
str
=
"localhost"
VLLM_NIXL_SIDE_CHANNEL_HOST
:
str
=
"localhost"
VLLM_NIXL_SIDE_CHANNEL_PORT
:
int
=
5557
VLLM_NIXL_SIDE_CHANNEL_PORT
:
int
=
5557
VLLM_ALL2ALL_BACKEND
:
str
=
"naive"
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
...
@@ -764,6 +765,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Port used for NIXL handshake between remote agents.
# Port used for NIXL handshake between remote agents.
"VLLM_NIXL_SIDE_CHANNEL_PORT"
:
"VLLM_NIXL_SIDE_CHANNEL_PORT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_SIDE_CHANNEL_PORT"
,
"5557"
)),
lambda
:
int
(
os
.
getenv
(
"VLLM_NIXL_SIDE_CHANNEL_PORT"
,
"5557"
)),
# all2all backend for vllm's expert parallel communication
"VLLM_ALL2ALL_BACKEND"
:
lambda
:
os
.
getenv
(
"VLLM_ALL2ALL_BACKEND"
,
"naive"
),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/model_executor/layers/fused_moe/layer.py
View file @
6266c57b
...
@@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter
...
@@ -10,7 +10,8 @@ from torch.nn.parameter import UninitializedParameter
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
get_dp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_dp_group
,
get_ep_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
@@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):
...
@@ -832,24 +833,6 @@ class FusedMoE(torch.nn.Module):
return
topk_weights
,
topk_ids
return
topk_weights
,
topk_ids
def
naive_multicast
(
self
,
x
:
torch
.
Tensor
,
cu_tokens_across_dp_cpu
:
torch
.
Tensor
):
assert
(
len
(
x
.
shape
)
==
2
)
buffer
=
torch
.
empty
((
cu_tokens_across_dp_cpu
[
-
1
],
x
.
size
(
1
)),
device
=
x
.
device
,
dtype
=
x
.
dtype
)
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
buffer
[
start
:
end
,
:].
copy_
(
x
)
for
idx
in
range
(
get_dp_group
().
world_size
):
start
=
0
if
idx
==
0
else
cu_tokens_across_dp_cpu
[
idx
-
1
]
end
=
cu_tokens_across_dp_cpu
[
idx
]
get_dp_group
().
broadcast
(
buffer
[
start
:
end
,
:],
idx
)
return
buffer
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
def
forward
(
self
,
hidden_states
:
torch
.
Tensor
,
router_logits
:
torch
.
Tensor
):
router_logits
:
torch
.
Tensor
):
if
self
.
use_direct_call
:
if
self
.
use_direct_call
:
...
@@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
...
@@ -863,14 +846,8 @@ class FusedMoE(torch.nn.Module):
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
if
self
.
dp_size
>
1
:
if
self
.
dp_size
>
1
:
cu_tokens_across_dp_cpu
=
get_forward_context
(
hidden_states
,
router_logits
=
get_ep_group
().
dispatch
(
).
dp_metadata
.
cu_tokens_across_dp_cpu
hidden_states
,
router_logits
)
hidden_states
=
self
.
naive_multicast
(
hidden_states
,
cu_tokens_across_dp_cpu
)
router_logits
=
self
.
naive_multicast
(
router_logits
,
cu_tokens_across_dp_cpu
)
# Matrix multiply.
# Matrix multiply.
final_hidden_states
=
self
.
quant_method
.
apply
(
final_hidden_states
=
self
.
quant_method
.
apply
(
layer
=
self
,
layer
=
self
,
...
@@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
...
@@ -891,12 +868,7 @@ class FusedMoE(torch.nn.Module):
)
)
if
self
.
dp_size
>
1
:
if
self
.
dp_size
>
1
:
start
=
0
if
self
.
dp_rank
==
0
else
cu_tokens_across_dp_cpu
[
final_hidden_states
=
get_ep_group
().
combine
(
final_hidden_states
)
self
.
dp_rank
-
1
]
end
=
cu_tokens_across_dp_cpu
[
self
.
dp_rank
]
all_hidden_states
=
get_dp_group
().
all_reduce
(
final_hidden_states
)
final_hidden_states
=
all_hidden_states
[
start
:
end
,
:]
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
if
self
.
reduce_results
and
(
self
.
tp_size
>
1
or
self
.
ep_size
>
1
):
# Default set to False. (May have to add shared expert outputs.)
# Default set to False. (May have to add shared expert outputs.)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
6266c57b
...
@@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
...
@@ -19,7 +19,8 @@ from vllm.config import (CompilationLevel, VllmConfig,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
)
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.parallel_state
import
get_pp_group
,
graph_capture
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
graph_capture
,
prepare_communication_buffer_for_model
)
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
@@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1457,6 +1458,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logger
.
info
(
"Model loading took %.4f GiB and %.6f seconds"
,
logger
.
info
(
"Model loading took %.4f GiB and %.6f seconds"
,
self
.
model_memory_usage
/
GiB_bytes
,
self
.
model_memory_usage
/
GiB_bytes
,
time_after_load
-
time_before_load
)
time_after_load
-
time_before_load
)
prepare_communication_buffer_for_model
(
self
.
model
)
def
_get_prompt_logprobs_dict
(
def
_get_prompt_logprobs_dict
(
self
,
self
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment