Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b87c21fc
Unverified
Commit
b87c21fc
authored
Mar 03, 2025
by
Mengqing Cao
Committed by
GitHub
Mar 03, 2025
Browse files
[Misc][Platform] Move use allgather to platform (#14010)
Signed-off-by:
Mengqing Cao
<
cmq0113@163.com
>
parent
e584b85a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
24 additions
and
7 deletions
+24
-7
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+3
-7
vllm/platforms/interface.py
vllm/platforms/interface.py
+13
-0
vllm/platforms/neuron.py
vllm/platforms/neuron.py
+4
-0
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+4
-0
No files found.
vllm/model_executor/layers/logits_processor.py
View file @
b87c21fc
...
...
@@ -8,7 +8,6 @@ import torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -51,11 +50,7 @@ class LogitsProcessor(nn.Module):
# Soft cap the logits. Used in Gemma 2.
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
parallel_config
=
get_current_vllm_config
().
parallel_config
self
.
use_all_gather
=
current_platform
.
is_tpu
()
\
or
current_platform
.
is_neuron
()
\
or
envs
.
VLLM_USE_V1
\
or
parallel_config
.
distributed_executor_backend
==
"external_launcher"
# noqa
self
.
use_all_gather
=
current_platform
.
use_all_gather
()
def
forward
(
self
,
...
...
@@ -83,7 +78,8 @@ class LogitsProcessor(nn.Module):
logits
*=
self
.
scale
# Apply logits processors (if any).
if
sampling_metadata
is
not
None
:
if
sampling_metadata
is
not
None
and
\
sampling_metadata
.
seq_groups
is
not
None
:
logits
=
_apply_logits_processors
(
logits
,
sampling_metadata
)
return
logits
...
...
vllm/platforms/interface.py
View file @
b87c21fc
...
...
@@ -330,6 +330,19 @@ class Platform:
"""
return
"vllm.distributed.device_communicators.base_device_communicator.DeviceCommunicatorBase"
# noqa
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
"""
Whether to use allgather in LogitsProcessor to gather the logits.
"""
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
parallel_config
=
get_current_vllm_config
().
parallel_config
return
(
envs
.
VLLM_USE_V1
or
parallel_config
.
distributed_executor_backend
==
"external_launcher"
)
class
UnspecifiedPlatform
(
Platform
):
_enum
=
PlatformEnum
.
UNSPECIFIED
...
...
vllm/platforms/neuron.py
View file @
b87c21fc
...
...
@@ -55,3 +55,7 @@ class NeuronPlatform(Platform):
def
is_pin_memory_available
(
cls
)
->
bool
:
logger
.
warning
(
"Pin memory is not supported on Neuron."
)
return
False
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
vllm/platforms/tpu.py
View file @
b87c21fc
...
...
@@ -119,3 +119,7 @@ class TpuPlatform(Platform):
@
classmethod
def
get_device_communicator_cls
(
cls
)
->
str
:
return
"vllm.distributed.device_communicators.tpu_communicator.TpuCommunicator"
# noqa
@
classmethod
def
use_all_gather
(
cls
)
->
bool
:
return
True
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment